.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "examples/mlearning/regression_model/plot_random_forest_regression.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. .. rst-class:: sphx-glr-example-title .. _sphx_glr_examples_mlearning_regression_model_plot_random_forest_regression.py: Random forest ============= A :class:`.RandomForestRegressor` is a random forest model based on `scikit-learn `__. .. GENERATED FROM PYTHON SOURCE LINES 28-42 .. code-block:: Python from __future__ import annotations from matplotlib import pyplot as plt from numpy import array from gemseo import configure_logger from gemseo import create_design_space from gemseo import create_discipline from gemseo import sample_disciplines from gemseo.mlearning import create_regression_model configure_logger() .. rst-class:: sphx-glr-script-out .. code-block:: none .. GENERATED FROM PYTHON SOURCE LINES 43-48 Problem ------- In this example, we represent the function :math:`f(x)=(6x-2)^2\sin(12x-4)` :cite:`forrester2008` by the :class:`.AnalyticDiscipline` .. GENERATED FROM PYTHON SOURCE LINES 48-53 .. code-block:: Python discipline = create_discipline( "AnalyticDiscipline", name="f", expressions={"y": "(6*x-2)**2*sin(12*x-4)"}, ) .. GENERATED FROM PYTHON SOURCE LINES 54-55 and seek to approximate it over the input space .. GENERATED FROM PYTHON SOURCE LINES 55-58 .. code-block:: Python input_space = create_design_space() input_space.add_variable("x", lower_bound=0.0, upper_bound=1.0) .. GENERATED FROM PYTHON SOURCE LINES 59-61 To do this, we create a training dataset with 6 equispaced points: .. GENERATED FROM PYTHON SOURCE LINES 61-65 .. code-block:: Python training_dataset = sample_disciplines( [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=6 ) .. rst-class:: sphx-glr-script-out .. 
code-block:: none WARNING - 11:42:46: No coupling in MDA, switching chain_linearize to True. INFO - 11:42:46: *** Start Sampling execution *** INFO - 11:42:46: Sampling INFO - 11:42:46: Disciplines: f INFO - 11:42:46: MDO formulation: MDF INFO - 11:42:46: Running the algorithm PYDOE_FULLFACT: INFO - 11:42:46: 17%|█▋ | 1/6 [00:00<00:00, 561.71 it/sec] INFO - 11:42:46: 33%|███▎ | 2/6 [00:00<00:00, 888.81 it/sec] INFO - 11:42:46: 50%|█████ | 3/6 [00:00<00:00, 1138.52 it/sec] INFO - 11:42:46: 67%|██████▋ | 4/6 [00:00<00:00, 1343.04 it/sec] INFO - 11:42:46: 83%|████████▎ | 5/6 [00:00<00:00, 1493.80 it/sec] INFO - 11:42:46: 100%|██████████| 6/6 [00:00<00:00, 1630.76 it/sec] INFO - 11:42:46: *** End Sampling execution (time: 0:00:00.004787) *** .. GENERATED FROM PYTHON SOURCE LINES 66-72 Basics ------ Training ~~~~~~~~ Then, we train a random forest regression model from these samples: .. GENERATED FROM PYTHON SOURCE LINES 72-75 .. code-block:: Python model = create_regression_model("RandomForestRegressor", training_dataset) model.learn() .. GENERATED FROM PYTHON SOURCE LINES 76-80 Prediction ~~~~~~~~~~ Once it is built, we can predict the output value of :math:`f` at a new input point: .. GENERATED FROM PYTHON SOURCE LINES 80-84 .. code-block:: Python input_value = {"x": array([0.65])} output_value = model.predict(input_value) output_value .. rst-class:: sphx-glr-script-out .. code-block:: none {'y': array([-0.88837697])} .. GENERATED FROM PYTHON SOURCE LINES 85-86 but cannot predict its Jacobian value: .. GENERATED FROM PYTHON SOURCE LINES 86-91 .. code-block:: Python try: model.predict_jacobian(input_value) except NotImplementedError: print("The derivatives are not available for RandomForestRegressor.") .. rst-class:: sphx-glr-script-out .. code-block:: none The derivatives are not available for RandomForestRegressor. .. 
GENERATED FROM PYTHON SOURCE LINES 92-96 Plotting ~~~~~~~~ The plot shows that the random forest model approximates :math:`f` fairly well on the left of the input space but poorly on the right: .. GENERATED FROM PYTHON SOURCE LINES 96-108 .. code-block:: Python test_dataset = sample_disciplines( [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100 ) input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy() reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel() predicted_output_data = model.predict(input_data).ravel() plt.plot(input_data.ravel(), reference_output_data, label="Reference") plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics") plt.grid() plt.legend() plt.show() .. image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_random_forest_regression_001.png :alt: plot random forest regression :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_random_forest_regression_001.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none WARNING - 11:42:46: No coupling in MDA, switching chain_linearize to True. 
INFO - 11:42:46: *** Start Sampling execution *** INFO - 11:42:46: Sampling INFO - 11:42:46: Disciplines: f INFO - 11:42:46: MDO formulation: MDF INFO - 11:42:46: Running the algorithm PYDOE_FULLFACT: INFO - 11:42:46: 1%| | 1/100 [00:00<00:00, 2549.73 it/sec] INFO - 11:42:46: 2%|▏ | 2/100 [00:00<00:00, 2465.79 it/sec] INFO - 11:42:46: 3%|▎ | 3/100 [00:00<00:00, 2529.74 it/sec] INFO - 11:42:46: 4%|▍ | 4/100 [00:00<00:00, 2623.08 it/sec] INFO - 11:42:46: 5%|▌ | 5/100 [00:00<00:00, 2701.82 it/sec] INFO - 11:42:46: 6%|▌ | 6/100 [00:00<00:00, 2710.67 it/sec] INFO - 11:42:46: 7%|▋ | 7/100 [00:00<00:00, 2751.39 it/sec] INFO - 11:42:46: 8%|▊ | 8/100 [00:00<00:00, 2811.43 it/sec] INFO - 11:42:46: 9%|▉ | 9/100 [00:00<00:00, 2839.96 it/sec] INFO - 11:42:46: 10%|█ | 10/100 [00:00<00:00, 2875.57 it/sec] INFO - 11:42:46: 11%|█ | 11/100 [00:00<00:00, 2906.84 it/sec] INFO - 11:42:46: 12%|█▏ | 12/100 [00:00<00:00, 2925.41 it/sec] INFO - 11:42:46: 13%|█▎ | 13/100 [00:00<00:00, 2949.26 it/sec] INFO - 11:42:46: 14%|█▍ | 14/100 [00:00<00:00, 2971.52 it/sec] INFO - 11:42:46: 15%|█▌ | 15/100 [00:00<00:00, 2999.07 it/sec] INFO - 11:42:46: 16%|█▌ | 16/100 [00:00<00:00, 3006.94 it/sec] INFO - 11:42:46: 17%|█▋ | 17/100 [00:00<00:00, 3021.45 it/sec] INFO - 11:42:46: 18%|█▊ | 18/100 [00:00<00:00, 3040.09 it/sec] INFO - 11:42:46: 19%|█▉ | 19/100 [00:00<00:00, 3030.57 it/sec] INFO - 11:42:46: 20%|██ | 20/100 [00:00<00:00, 3037.92 it/sec] INFO - 11:42:46: 21%|██ | 21/100 [00:00<00:00, 3052.52 it/sec] INFO - 11:42:46: 22%|██▏ | 22/100 [00:00<00:00, 3053.13 it/sec] INFO - 11:42:46: 23%|██▎ | 23/100 [00:00<00:00, 3056.98 it/sec] INFO - 11:42:46: 24%|██▍ | 24/100 [00:00<00:00, 3068.44 it/sec] INFO - 11:42:46: 25%|██▌ | 25/100 [00:00<00:00, 3073.20 it/sec] INFO - 11:42:46: 26%|██▌ | 26/100 [00:00<00:00, 3082.48 it/sec] INFO - 11:42:46: 27%|██▋ | 27/100 [00:00<00:00, 3088.42 it/sec] INFO - 11:42:46: 28%|██▊ | 28/100 [00:00<00:00, 3099.92 it/sec] INFO - 11:42:46: 29%|██▉ | 29/100 [00:00<00:00, 3101.11 
it/sec] INFO - 11:42:46: 30%|███ | 30/100 [00:00<00:00, 3108.50 it/sec] INFO - 11:42:46: 31%|███ | 31/100 [00:00<00:00, 3117.84 it/sec] INFO - 11:42:46: 32%|███▏ | 32/100 [00:00<00:00, 3119.24 it/sec] INFO - 11:42:46: 33%|███▎ | 33/100 [00:00<00:00, 3126.19 it/sec] INFO - 11:42:46: 34%|███▍ | 34/100 [00:00<00:00, 3134.83 it/sec] INFO - 11:42:46: 35%|███▌ | 35/100 [00:00<00:00, 3133.02 it/sec] INFO - 11:42:46: 36%|███▌ | 36/100 [00:00<00:00, 3132.94 it/sec] INFO - 11:42:46: 37%|███▋ | 37/100 [00:00<00:00, 3139.07 it/sec] INFO - 11:42:46: 38%|███▊ | 38/100 [00:00<00:00, 3141.18 it/sec] INFO - 11:42:46: 39%|███▉ | 39/100 [00:00<00:00, 3146.33 it/sec] INFO - 11:42:46: 40%|████ | 40/100 [00:00<00:00, 3149.65 it/sec] INFO - 11:42:46: 41%|████ | 41/100 [00:00<00:00, 3112.80 it/sec] INFO - 11:42:46: 42%|████▏ | 42/100 [00:00<00:00, 3107.93 it/sec] INFO - 11:42:46: 43%|████▎ | 43/100 [00:00<00:00, 3110.16 it/sec] INFO - 11:42:46: 44%|████▍ | 44/100 [00:00<00:00, 3109.46 it/sec] INFO - 11:42:46: 45%|████▌ | 45/100 [00:00<00:00, 3112.89 it/sec] INFO - 11:42:46: 46%|████▌ | 46/100 [00:00<00:00, 3117.94 it/sec] INFO - 11:42:46: 47%|████▋ | 47/100 [00:00<00:00, 3116.72 it/sec] INFO - 11:42:46: 48%|████▊ | 48/100 [00:00<00:00, 3117.96 it/sec] INFO - 11:42:46: 49%|████▉ | 49/100 [00:00<00:00, 3122.52 it/sec] INFO - 11:42:46: 50%|█████ | 50/100 [00:00<00:00, 3127.14 it/sec] INFO - 11:42:46: 51%|█████ | 51/100 [00:00<00:00, 3126.51 it/sec] INFO - 11:42:46: 52%|█████▏ | 52/100 [00:00<00:00, 3126.53 it/sec] INFO - 11:42:46: 53%|█████▎ | 53/100 [00:00<00:00, 3130.03 it/sec] INFO - 11:42:46: 54%|█████▍ | 54/100 [00:00<00:00, 3130.64 it/sec] INFO - 11:42:46: 55%|█████▌ | 55/100 [00:00<00:00, 3134.76 it/sec] INFO - 11:42:46: 56%|█████▌ | 56/100 [00:00<00:00, 3140.33 it/sec] INFO - 11:42:46: 57%|█████▋ | 57/100 [00:00<00:00, 3141.43 it/sec] INFO - 11:42:46: 58%|█████▊ | 58/100 [00:00<00:00, 3142.98 it/sec] INFO - 11:42:46: 59%|█████▉ | 59/100 [00:00<00:00, 3146.71 it/sec] INFO - 11:42:46: 
60%|██████ | 60/100 [00:00<00:00, 3150.73 it/sec] INFO - 11:42:46: 61%|██████ | 61/100 [00:00<00:00, 3150.12 it/sec] INFO - 11:42:46: 62%|██████▏ | 62/100 [00:00<00:00, 3154.03 it/sec] INFO - 11:42:46: 63%|██████▎ | 63/100 [00:00<00:00, 3158.59 it/sec] INFO - 11:42:46: 64%|██████▍ | 64/100 [00:00<00:00, 3159.44 it/sec] INFO - 11:42:46: 65%|██████▌ | 65/100 [00:00<00:00, 3161.18 it/sec] INFO - 11:42:46: 66%|██████▌ | 66/100 [00:00<00:00, 3163.85 it/sec] INFO - 11:42:46: 67%|██████▋ | 67/100 [00:00<00:00, 3162.77 it/sec] INFO - 11:42:46: 68%|██████▊ | 68/100 [00:00<00:00, 3164.53 it/sec] INFO - 11:42:46: 69%|██████▉ | 69/100 [00:00<00:00, 3168.32 it/sec] INFO - 11:42:46: 70%|███████ | 70/100 [00:00<00:00, 3170.16 it/sec] INFO - 11:42:46: 71%|███████ | 71/100 [00:00<00:00, 3172.76 it/sec] INFO - 11:42:46: 72%|███████▏ | 72/100 [00:00<00:00, 3175.93 it/sec] INFO - 11:42:46: 73%|███████▎ | 73/100 [00:00<00:00, 3179.12 it/sec] INFO - 11:42:46: 74%|███████▍ | 74/100 [00:00<00:00, 3179.36 it/sec] INFO - 11:42:46: 75%|███████▌ | 75/100 [00:00<00:00, 3181.78 it/sec] INFO - 11:42:46: 76%|███████▌ | 76/100 [00:00<00:00, 3185.35 it/sec] INFO - 11:42:46: 77%|███████▋ | 77/100 [00:00<00:00, 3185.37 it/sec] INFO - 11:42:46: 78%|███████▊ | 78/100 [00:00<00:00, 3187.91 it/sec] INFO - 11:42:46: 79%|███████▉ | 79/100 [00:00<00:00, 3189.19 it/sec] INFO - 11:42:46: 80%|████████ | 80/100 [00:00<00:00, 3192.68 it/sec] INFO - 11:42:46: 81%|████████ | 81/100 [00:00<00:00, 3192.52 it/sec] INFO - 11:42:46: 82%|████████▏ | 82/100 [00:00<00:00, 3195.48 it/sec] INFO - 11:42:46: 83%|████████▎ | 83/100 [00:00<00:00, 3198.96 it/sec] INFO - 11:42:46: 84%|████████▍ | 84/100 [00:00<00:00, 3198.50 it/sec] INFO - 11:42:46: 85%|████████▌ | 85/100 [00:00<00:00, 3200.18 it/sec] INFO - 11:42:46: 86%|████████▌ | 86/100 [00:00<00:00, 3202.58 it/sec] INFO - 11:42:46: 87%|████████▋ | 87/100 [00:00<00:00, 3200.86 it/sec] INFO - 11:42:46: 88%|████████▊ | 88/100 [00:00<00:00, 3201.79 it/sec] INFO - 11:42:46: 
89%|████████▉ | 89/100 [00:00<00:00, 3204.29 it/sec] INFO - 11:42:46: 90%|█████████ | 90/100 [00:00<00:00, 3205.46 it/sec] INFO - 11:42:46: 91%|█████████ | 91/100 [00:00<00:00, 3206.12 it/sec] INFO - 11:42:46: 92%|█████████▏| 92/100 [00:00<00:00, 3206.71 it/sec] INFO - 11:42:46: 93%|█████████▎| 93/100 [00:00<00:00, 3208.05 it/sec] INFO - 11:42:46: 94%|█████████▍| 94/100 [00:00<00:00, 3207.41 it/sec] INFO - 11:42:46: 95%|█████████▌| 95/100 [00:00<00:00, 3209.78 it/sec] INFO - 11:42:46: 96%|█████████▌| 96/100 [00:00<00:00, 3208.85 it/sec] INFO - 11:42:46: 97%|█████████▋| 97/100 [00:00<00:00, 3207.44 it/sec] INFO - 11:42:46: 98%|█████████▊| 98/100 [00:00<00:00, 3207.25 it/sec] INFO - 11:42:46: 99%|█████████▉| 99/100 [00:00<00:00, 3204.38 it/sec] INFO - 11:42:46: 100%|██████████| 100/100 [00:00<00:00, 3200.73 it/sec] INFO - 11:42:46: *** End Sampling execution (time: 0:00:00.032584) *** .. GENERATED FROM PYTHON SOURCE LINES 109-116 Settings -------- Number of estimators ~~~~~~~~~~~~~~~~~~~~ The main hyperparameter of random forest regression is the number of trees in the forest (default: 100). Here is a comparison when increasing and decreasing this number: .. GENERATED FROM PYTHON SOURCE LINES 116-134 .. code-block:: Python model = create_regression_model( "RandomForestRegressor", training_dataset, n_estimators=10 ) model.learn() predicted_output_data_1 = model.predict(input_data).ravel() model = create_regression_model( "RandomForestRegressor", training_dataset, n_estimators=1000 ) model.learn() predicted_output_data_2 = model.predict(input_data).ravel() plt.plot(input_data.ravel(), reference_output_data, label="Reference") plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics") plt.plot(input_data.ravel(), predicted_output_data_1, label="Regression - 10 trees") plt.plot(input_data.ravel(), predicted_output_data_2, label="Regression - 1000 trees") plt.grid() plt.legend() plt.show() .. 
image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_random_forest_regression_002.png :alt: plot random forest regression :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_random_forest_regression_002.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 135-145 Others ------ The ``RandomForestRegressor`` class of scikit-learn has many settings (`read more `__), and we have chosen to expose only ``n_estimators``. However, any argument of ``RandomForestRegressor`` can be set using the dictionary ``parameters``. For example, we can impose a minimum of two samples per leaf: .. GENERATED FROM PYTHON SOURCE LINES 145-156 .. code-block:: Python model = create_regression_model( "RandomForestRegressor", training_dataset, parameters={"min_samples_leaf": 2} ) model.learn() predicted_output_data_ = model.predict(input_data).ravel() plt.plot(input_data.ravel(), reference_output_data, label="Reference") plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics") plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - 2 samples") plt.grid() plt.legend() plt.show() .. image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_random_forest_regression_003.png :alt: plot random forest regression :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_random_forest_regression_003.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 1.443 seconds) .. _sphx_glr_download_examples_mlearning_regression_model_plot_random_forest_regression.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot_random_forest_regression.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_random_forest_regression.py ` .. 
container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: plot_random_forest_regression.zip ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_
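
As a complement to the Prediction section, here is a minimal pure-Python sketch of why a random forest has no useful Jacobian. It is neither the GEMSEO nor the scikit-learn implementation. It assumes that the six equispaced training points include both bounds of :math:`[0, 1]`, and it relies on the fact that, for a single input variable, a fully grown regression tree behaves like a nearest-neighbour predictor on its bootstrap sample. The names ``nearest_value`` and ``forest_predict`` are hypothetical helpers introduced for illustration.

```python
import math
import random


def f(x):
    """Function used in this example: (6x-2)^2 sin(12x-4)."""
    return (6 * x - 2) ** 2 * math.sin(12 * x - 4)


def nearest_value(x, points):
    """Mimic a fully grown 1D tree: return the output of the closest sample."""
    return min(points, key=lambda point: abs(point[0] - x))[1]


def forest_predict(x, trees):
    """Average the predictions of all the trees."""
    return sum(nearest_value(x, tree) for tree in trees) / len(trees)


# Six equispaced training points on [0, 1] (assumption: bounds included).
training = [(i / 5, f(i / 5)) for i in range(6)]

# Each tree is built from a bootstrap resample of the training set.
rng = random.Random(0)
trees = [[rng.choice(training) for _ in training] for _ in range(100)]

# Two nearby inputs fall in the same cell of every tree,
# so the forest returns exactly the same value for both.
y_a = forest_predict(0.65, trees)
y_b = forest_predict(0.66, trees)
print(y_a == y_b)
```

In this sketch the prediction is piecewise constant: its derivative is zero inside each cell and undefined at the cell boundaries, which is consistent with ``predict_jacobian`` raising ``NotImplementedError`` for this model.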