.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/mlearning/regression_model/plot_linear_regression.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_mlearning_regression_model_plot_linear_regression.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_mlearning_regression_model_plot_linear_regression.py:


Linear regression
=================

A :class:`.LinearRegressor` is a linear regression model
based on scikit-learn.

.. seealso::

   The scikit-learn documentation gives more information
   about building linear models.

.. GENERATED FROM PYTHON SOURCE LINES 32-43

.. code-block:: Python

    from __future__ import annotations

    from matplotlib import pyplot as plt
    from numpy import array

    from gemseo import create_design_space
    from gemseo import create_discipline
    from gemseo import sample_disciplines
    from gemseo.mlearning import create_regression_model

.. GENERATED FROM PYTHON SOURCE LINES 44-49

Problem
-------
In this example,
we represent the function :math:`f(x)=(6x-2)^2\sin(12x-4)` :cite:`forrester2008`
by the :class:`.AnalyticDiscipline`:

.. GENERATED FROM PYTHON SOURCE LINES 49-54

.. code-block:: Python

    discipline = create_discipline(
        "AnalyticDiscipline",
        name="f",
        expressions={"y": "(6*x-2)**2*sin(12*x-4)"},
    )

.. GENERATED FROM PYTHON SOURCE LINES 55-56

and seek to approximate it over the input space:

.. GENERATED FROM PYTHON SOURCE LINES 56-59

.. code-block:: Python

    input_space = create_design_space()
    input_space.add_variable("x", lower_bound=0.0, upper_bound=1.0)

.. GENERATED FROM PYTHON SOURCE LINES 60-62

To do this,
we create a training dataset with 6 equispaced points:

.. GENERATED FROM PYTHON SOURCE LINES 62-66

.. code-block:: Python

    training_dataset = sample_disciplines(
        [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=6
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    INFO - 16:22:19: *** Start Sampling execution ***
    INFO - 16:22:19: Sampling
    INFO - 16:22:19: Disciplines: f
    INFO - 16:22:19: MDO formulation: MDF
    INFO - 16:22:19: Running the algorithm PYDOE_FULLFACT:
    INFO - 16:22:19: 17%|█▋ | 1/6 [00:00<00:00, 646.97 it/sec]
    INFO - 16:22:19: 33%|███▎ | 2/6 [00:00<00:00, 1055.97 it/sec]
    INFO - 16:22:19: 50%|█████ | 3/6 [00:00<00:00, 1387.92 it/sec]
    INFO - 16:22:19: 67%|██████▋ | 4/6 [00:00<00:00, 1642.73 it/sec]
    INFO - 16:22:19: 83%|████████▎ | 5/6 [00:00<00:00, 1871.29 it/sec]
    INFO - 16:22:19: 100%|██████████| 6/6 [00:00<00:00, 2013.75 it/sec]
    INFO - 16:22:19: *** End Sampling execution ***

.. GENERATED FROM PYTHON SOURCE LINES 67-73

Basics
------

Training
~~~~~~~~

Then, we train a linear regression model from these samples:

.. GENERATED FROM PYTHON SOURCE LINES 73-76

.. code-block:: Python

    model = create_regression_model("LinearRegressor", training_dataset)
    model.learn()

.. GENERATED FROM PYTHON SOURCE LINES 77-81

Prediction
~~~~~~~~~~

Once the model is trained,
we can predict the output value of :math:`f` at a new input point:

.. GENERATED FROM PYTHON SOURCE LINES 81-85

.. code-block:: Python

    input_value = {"x": array([0.65])}
    output_value = model.predict(input_value)
    output_value

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    {'y': array([3.29457456])}

.. GENERATED FROM PYTHON SOURCE LINES 86-87

as well as its Jacobian value:

.. GENERATED FROM PYTHON SOURCE LINES 87-90

.. code-block:: Python

    jacobian_value = model.predict_jacobian(input_value)
    jacobian_value

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    {'y': {'x': array([[7.26002643]])}}

.. GENERATED FROM PYTHON SOURCE LINES 91-95

Plotting
~~~~~~~~

As expected, the linear model is a poor approximation
of this strongly nonlinear function:

.. GENERATED FROM PYTHON SOURCE LINES 95-107

.. code-block:: Python

    # Sample the reference function on a fine grid of 100 points.
    test_dataset = sample_disciplines(
        [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100
    )
    input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy()
    reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel()
    # Evaluate the regression model at the same input points.
    predicted_output_data = model.predict(input_data).ravel()
    plt.plot(input_data.ravel(), reference_output_data, label="Reference")
    plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
    plt.grid()
    plt.legend()
    plt.show()

.. image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_linear_regression_001.png
   :alt: plot linear regression
   :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_linear_regression_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    INFO - 16:22:19: *** Start Sampling execution ***
    INFO - 16:22:19: Sampling
    INFO - 16:22:19: Disciplines: f
    INFO - 16:22:19: MDO formulation: MDF
    INFO - 16:22:19: Running the algorithm PYDOE_FULLFACT:
    INFO - 16:22:19: 1%| | 1/100 [00:00<00:00, 3728.27 it/sec]
    INFO - 16:22:19: 2%|▏ | 2/100 [00:00<00:00, 3623.59 it/sec]
    INFO - 16:22:19: 3%|▎ | 3/100 [00:00<00:00, 3686.76 it/sec]
    INFO - 16:22:19: 4%|▍ | 4/100 [00:00<00:00, 3812.14 it/sec]
    INFO - 16:22:19: 5%|▌ | 5/100 [00:00<00:00, 3807.47 it/sec]
    INFO - 16:22:19: 6%|▌ | 6/100 [00:00<00:00, 3892.02 it/sec]
    INFO - 16:22:19: 7%|▋ | 7/100 [00:00<00:00, 3965.98 it/sec]
    INFO - 16:22:19: 8%|▊ | 8/100 [00:00<00:00, 4022.35 it/sec]
    INFO - 16:22:19: 9%|▉ | 9/100 [00:00<00:00, 4018.82 it/sec]
    INFO - 16:22:19: 10%|█ | 10/100 [00:00<00:00, 4053.25 it/sec]
    INFO - 16:22:19: 11%|█ | 11/100 [00:00<00:00, 4086.57 it/sec]
    INFO - 16:22:19: 12%|█▏ | 12/100 [00:00<00:00, 4114.42 it/sec]
    INFO - 16:22:19: 13%|█▎ | 13/100 [00:00<00:00, 4099.39 it/sec]
    INFO - 16:22:19: 14%|█▍ | 14/100 [00:00<00:00, 4112.35 it/sec]
    INFO - 16:22:19: 15%|█▌ | 15/100 [00:00<00:00, 4136.39 it/sec]
    INFO - 16:22:19: 16%|█▌ | 16/100 [00:00<00:00, 4162.56 it/sec]
    INFO - 16:22:19: 17%|█▋ | 17/100 [00:00<00:00, 4188.14 it/sec]
    INFO - 16:22:19: 18%|█▊ | 18/100 [00:00<00:00, 4184.54 it/sec]
    INFO - 16:22:19: 19%|█▉ | 19/100 [00:00<00:00, 4199.17 it/sec]
    INFO - 16:22:19: 20%|██ | 20/100 [00:00<00:00, 4206.71 it/sec]
    INFO - 16:22:19: 21%|██ | 21/100 [00:00<00:00, 4221.85 it/sec]
    INFO - 16:22:19: 22%|██▏ | 22/100 [00:00<00:00, 4214.42 it/sec]
    INFO - 16:22:19: 23%|██▎ | 23/100 [00:00<00:00, 4220.73 it/sec]
    INFO - 16:22:19: 24%|██▍ | 24/100 [00:00<00:00, 4233.64 it/sec]
    INFO - 16:22:19: 25%|██▌ | 25/100 [00:00<00:00, 4246.62 it/sec]
    INFO - 16:22:19: 26%|██▌ | 26/100 [00:00<00:00, 4244.92 it/sec]
    INFO - 16:22:19: 27%|██▋ | 27/100 [00:00<00:00, 4251.46 it/sec]
    INFO - 16:22:19: 28%|██▊ | 28/100 [00:00<00:00, 4260.18 it/sec]
    INFO - 16:22:19: 29%|██▉ | 29/100 [00:00<00:00, 4272.99 it/sec]
    INFO - 16:22:19: 30%|███ | 30/100 [00:00<00:00, 4283.98 it/sec]
    INFO - 16:22:19: 31%|███ | 31/100 [00:00<00:00, 4280.04 it/sec]
    INFO - 16:22:19: 32%|███▏ | 32/100 [00:00<00:00, 4288.38 it/sec]
    INFO - 16:22:19: 33%|███▎ | 33/100 [00:00<00:00, 4298.24 it/sec]
    INFO - 16:22:19: 34%|███▍ | 34/100 [00:00<00:00, 4306.66 it/sec]
    INFO - 16:22:19: 35%|███▌ | 35/100 [00:00<00:00, 4304.88 it/sec]
    INFO - 16:22:19: 36%|███▌ | 36/100 [00:00<00:00, 4309.46 it/sec]
    INFO - 16:22:19: 37%|███▋ | 37/100 [00:00<00:00, 4318.01 it/sec]
    INFO - 16:22:19: 38%|███▊ | 38/100 [00:00<00:00, 4320.27 it/sec]
    INFO - 16:22:19: 39%|███▉ | 39/100 [00:00<00:00, 4326.31 it/sec]
    INFO - 16:22:19: 40%|████ | 40/100 [00:00<00:00, 4323.13 it/sec]
    INFO - 16:22:19: 41%|████ | 41/100 [00:00<00:00, 4329.36 it/sec]
    INFO - 16:22:19: 42%|████▏ | 42/100 [00:00<00:00, 4336.80 it/sec]
    INFO - 16:22:19: 43%|████▎ | 43/100 [00:00<00:00, 4341.41 it/sec]
    INFO - 16:22:19: 44%|████▍ | 44/100 [00:00<00:00, 4341.32 it/sec]
    INFO - 16:22:19: 45%|████▌ | 45/100 [00:00<00:00, 4346.33 it/sec]
    INFO - 16:22:19: 46%|████▌ | 46/100 [00:00<00:00, 4352.70 it/sec]
    INFO - 16:22:19: 47%|████▋ | 47/100 [00:00<00:00, 4358.83 it/sec]
    INFO - 16:22:19: 48%|████▊ | 48/100 [00:00<00:00, 4363.67 it/sec]
    INFO - 16:22:19: 49%|████▉ | 49/100 [00:00<00:00, 4360.26 it/sec]
    INFO - 16:22:19: 50%|█████ | 50/100 [00:00<00:00, 4365.07 it/sec]
    INFO - 16:22:19: 51%|█████ | 51/100 [00:00<00:00, 4370.49 it/sec]
    INFO - 16:22:19: 52%|█████▏ | 52/100 [00:00<00:00, 4375.90 it/sec]
    INFO - 16:22:19: 53%|█████▎ | 53/100 [00:00<00:00, 4373.11 it/sec]
    INFO - 16:22:19: 54%|█████▍ | 54/100 [00:00<00:00, 4373.62 it/sec]
    INFO - 16:22:19: 55%|█████▌ | 55/100 [00:00<00:00, 4377.27 it/sec]
    INFO - 16:22:19: 56%|█████▌ | 56/100 [00:00<00:00, 4378.27 it/sec]
    INFO - 16:22:19: 57%|█████▋ | 57/100 [00:00<00:00, 4383.65 it/sec]
    INFO - 16:22:19: 58%|█████▊ | 58/100 [00:00<00:00, 4381.42 it/sec]
    INFO - 16:22:19: 59%|█████▉ | 59/100 [00:00<00:00, 4385.09 it/sec]
    INFO - 16:22:19: 60%|██████ | 60/100 [00:00<00:00, 4390.49 it/sec]
    INFO - 16:22:19: 61%|██████ | 61/100 [00:00<00:00, 4356.79 it/sec]
    INFO - 16:22:19: 62%|██████▏ | 62/100 [00:00<00:00, 4352.39 it/sec]
    INFO - 16:22:19: 63%|██████▎ | 63/100 [00:00<00:00, 4355.89 it/sec]
    INFO - 16:22:19: 64%|██████▍ | 64/100 [00:00<00:00, 4359.98 it/sec]
    INFO - 16:22:19: 65%|██████▌ | 65/100 [00:00<00:00, 4364.17 it/sec]
    INFO - 16:22:19: 66%|██████▌ | 66/100 [00:00<00:00, 4363.90 it/sec]
    INFO - 16:22:19: 67%|██████▋ | 67/100 [00:00<00:00, 4365.88 it/sec]
    INFO - 16:22:19: 68%|██████▊ | 68/100 [00:00<00:00, 4368.80 it/sec]
    INFO - 16:22:19: 69%|██████▉ | 69/100 [00:00<00:00, 4371.97 it/sec]
    INFO - 16:22:19: 70%|███████ | 70/100 [00:00<00:00, 4376.10 it/sec]
    INFO - 16:22:19: 71%|███████ | 71/100 [00:00<00:00, 4373.62 it/sec]
    INFO - 16:22:19: 72%|███████▏ | 72/100 [00:00<00:00, 4376.22 it/sec]
    INFO - 16:22:19: 73%|███████▎ | 73/100 [00:00<00:00, 4376.06 it/sec]
    INFO - 16:22:19: 74%|███████▍ | 74/100 [00:00<00:00, 4378.43 it/sec]
    INFO - 16:22:19: 75%|███████▌ | 75/100 [00:00<00:00, 4376.91 it/sec]
    INFO - 16:22:19: 76%|███████▌ | 76/100 [00:00<00:00, 4379.21 it/sec]
    INFO - 16:22:19: 77%|███████▋ | 77/100 [00:00<00:00, 4382.29 it/sec]
    INFO - 16:22:19: 78%|███████▊ | 78/100 [00:00<00:00, 4385.76 it/sec]
    INFO - 16:22:19: 79%|███████▉ | 79/100 [00:00<00:00, 4389.09 it/sec]
    INFO - 16:22:19: 80%|████████ | 80/100 [00:00<00:00, 4386.43 it/sec]
    INFO - 16:22:19: 81%|████████ | 81/100 [00:00<00:00, 4388.54 it/sec]
    INFO - 16:22:19: 82%|████████▏ | 82/100 [00:00<00:00, 4391.77 it/sec]
    INFO - 16:22:19: 83%|████████▎ | 83/100 [00:00<00:00, 4395.21 it/sec]
    INFO - 16:22:19: 84%|████████▍ | 84/100 [00:00<00:00, 4393.86 it/sec]
    INFO - 16:22:19: 85%|████████▌ | 85/100 [00:00<00:00, 4395.03 it/sec]
    INFO - 16:22:19: 86%|████████▌ | 86/100 [00:00<00:00, 4397.56 it/sec]
    INFO - 16:22:19: 87%|████████▋ | 87/100 [00:00<00:00, 4399.14 it/sec]
    INFO - 16:22:19: 88%|████████▊ | 88/100 [00:00<00:00, 4401.00 it/sec]
    INFO - 16:22:19: 89%|████████▉ | 89/100 [00:00<00:00, 4398.62 it/sec]
    INFO - 16:22:19: 90%|█████████ | 90/100 [00:00<00:00, 4399.88 it/sec]
    INFO - 16:22:19: 91%|█████████ | 91/100 [00:00<00:00, 4400.19 it/sec]
    INFO - 16:22:19: 92%|█████████▏| 92/100 [00:00<00:00, 4399.00 it/sec]
    INFO - 16:22:19: 93%|█████████▎| 93/100 [00:00<00:00, 4396.94 it/sec]
    INFO - 16:22:19: 94%|█████████▍| 94/100 [00:00<00:00, 4397.33 it/sec]
    INFO - 16:22:19: 95%|█████████▌| 95/100 [00:00<00:00, 4399.41 it/sec]
    INFO - 16:22:19: 96%|█████████▌| 96/100 [00:00<00:00, 4401.06 it/sec]
    INFO - 16:22:19: 97%|█████████▋| 97/100 [00:00<00:00, 4402.54 it/sec]
    INFO - 16:22:19: 98%|█████████▊| 98/100 [00:00<00:00, 4399.84 it/sec]
    INFO - 16:22:19: 99%|█████████▉| 99/100 [00:00<00:00, 4401.53 it/sec]
    INFO - 16:22:19: 100%|██████████| 100/100 [00:00<00:00, 4344.27 it/sec]
    INFO - 16:22:19: *** End Sampling execution ***

.. GENERATED FROM PYTHON SOURCE LINES 108-119

Settings
--------
The :class:`.LinearRegressor` has many options
defined in the :class:`.LinearRegressor_Settings` Pydantic model.

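Before changing any settings,
the default fit from the Basics section can be cross-checked without GEMSEO.
The following is a minimal NumPy sketch, not the library's implementation.
It assumes that the six training points are 0, 0.2, ..., 1
and that the default model is an ordinary least-squares line with an intercept;
``forrester`` and the variable names are introduced here for illustration only.

```python
import numpy as np


def forrester(x):
    """Evaluate f(x) = (6x - 2)^2 sin(12x - 4)."""
    return (6 * x - 2) ** 2 * np.sin(12 * x - 4)


# Six equispaced training points on [0, 1] (assumed to match the training dataset).
x_train = np.linspace(0.0, 1.0, 6)
y_train = forrester(x_train)

# Ordinary least squares for y ~ a_0 + a_1 * x:
# the design matrix has a column of ones (intercept) and a column of x values.
design = np.column_stack([np.ones_like(x_train), x_train])
(a_0, a_1), *_ = np.linalg.lstsq(design, y_train, rcond=None)

prediction = a_0 + a_1 * 0.65  # value of the fitted line at x = 0.65
slope = a_1  # derivative of the fitted line, the same for every x

print(f"prediction at x = 0.65: {prediction:.8f}")
print(f"slope: {slope:.8f}")
```

Under these assumptions,
the printed prediction and slope should agree, up to rounding,
with the values returned by ``model.predict`` and ``model.predict_jacobian``
in the Basics section.
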
Intercept
~~~~~~~~~
By default, the linear model is of the form :math:`a_0+a_1x_1+\ldots+a_dx_d`.
You can set the option ``fit_intercept`` to ``False``
if you want a linear model of the form :math:`a_1x_1+\ldots+a_dx_d`:

.. GENERATED FROM PYTHON SOURCE LINES 119-123

.. code-block:: Python

    model = create_regression_model(
        "LinearRegressor", training_dataset, fit_intercept=False, transformer={}
    )
    model.learn()

.. GENERATED FROM PYTHON SOURCE LINES 124-130

.. warning::

   The intercept is defined in the space of transformed variables.
   This is why we removed the default transformers
   by setting ``transformer`` to ``{}``.

We can see the impact of this option in the following visualization:

.. GENERATED FROM PYTHON SOURCE LINES 130-138

.. code-block:: Python

    predicted_output_data_ = model.predict(input_data).ravel()
    plt.plot(input_data.ravel(), reference_output_data, label="Reference")
    plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
    plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - No intercept")
    plt.grid()
    plt.legend()
    plt.show()

.. image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_linear_regression_002.png
   :alt: plot linear regression
   :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_linear_regression_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 139-157

Regularization
~~~~~~~~~~~~~~
When the number of samples is small relative to the input dimension,
regularization techniques can mitigate overfitting
(a model that fits the training data well but generalizes poorly).
The ``penalty_level`` option is a positive real number
defining the degree of regularization (default: no regularization).
By default, the regularization technique is the ridge penalty (l2 regularization).
The technique can be replaced by the lasso penalty (l1 regularization)
by setting the ``l2_penalty_ratio`` option to ``0.0``.

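The effect of the ridge and lasso penalties on a single coefficient
can be sketched with closed-form formulas.
This is an illustration under stated assumptions, not GEMSEO's implementation:
it works on the raw (untransformed) variables,
uses scikit-learn-style objectives with an unpenalized intercept,
and reuses the six equispaced training points;
``ridge_slope`` and ``lasso_slope`` are hypothetical helper names.
Because :class:`.LinearRegressor` applies default transformers,
its coefficients for a given ``penalty_level`` will generally differ from these values.

```python
import numpy as np


def forrester(x):
    """Evaluate f(x) = (6x - 2)^2 sin(12x - 4)."""
    return (6 * x - 2) ** 2 * np.sin(12 * x - 4)


x_train = np.linspace(0.0, 1.0, 6)
y_train = forrester(x_train)

# Centering the data removes the (unpenalized) intercept from the problem.
x_centered = x_train - x_train.mean()
y_centered = y_train - y_train.mean()
n_samples = x_centered.size


def ridge_slope(penalty_level):
    """Minimize sum((y - a*x)^2) + penalty_level * a^2 on centered data."""
    return (x_centered @ y_centered) / (x_centered @ x_centered + penalty_level)


def lasso_slope(penalty_level):
    """Minimize sum((y - a*x)^2) / (2n) + penalty_level * |a| on centered data."""
    correlation = (x_centered @ y_centered) / n_samples
    # Soft-thresholding: shrink toward zero and clip at zero.
    shrunk = np.sign(correlation) * max(abs(correlation) - penalty_level, 0.0)
    return shrunk / ((x_centered @ x_centered) / n_samples)


for level in (0.0, 0.5, 1.2):
    print(f"level={level}: ridge={ridge_slope(level):.4f}, lasso={lasso_slope(level):.4f}")
```

With a zero penalty level, both helpers return the ordinary least-squares slope.
As the level grows, the ridge slope shrinks smoothly toward zero,
whereas the lasso slope becomes exactly zero once the level exceeds a threshold.
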
When ``l2_penalty_ratio`` is strictly between 0 and 1,
the regularization technique is the elastic net penalty,
*i.e.* a linear combination of the ridge and lasso penalties
parametrized by this ``l2_penalty_ratio``.

For example, we can use the ridge penalty with a level of 1.2:

.. GENERATED FROM PYTHON SOURCE LINES 157-166

.. code-block:: Python

    model = create_regression_model("LinearRegressor", training_dataset, penalty_level=1.2)
    model.learn()
    predicted_output_data_ = model.predict(input_data).ravel()
    plt.plot(input_data.ravel(), reference_output_data, label="Reference")
    plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
    plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - Ridge(1.2)")
    plt.grid()
    plt.legend()
    plt.show()

.. image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_linear_regression_003.png
   :alt: plot linear regression
   :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_linear_regression_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 167-173

We can see that the penalty reduces the coefficient of the linear model.

.. note::

   For a model with many inputs,
   we could have used the lasso penalty
   and seen some coefficients set to zero.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.199 seconds)


.. _sphx_glr_download_examples_mlearning_regression_model_plot_linear_regression.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_linear_regression.ipynb <plot_linear_regression.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_linear_regression.py <plot_linear_regression.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_linear_regression.zip <plot_linear_regression.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_