Random forest#

A RandomForestRegressor is a random forest model based on scikit-learn.

from __future__ import annotations

from matplotlib import pyplot as plt
from numpy import array

from gemseo import create_design_space
from gemseo import create_discipline
from gemseo import sample_disciplines
from gemseo.mlearning import create_regression_model

Problem#

In this example, we represent the function \(f(x)=(6x-2)^2\sin(12x-4)\) [FSK08] by the AnalyticDiscipline:

discipline = create_discipline(
    "AnalyticDiscipline",
    name="f",
    expressions={"y": "(6*x-2)**2*sin(12*x-4)"},
)

and seek to approximate it over the input space:

input_space = create_design_space()
input_space.add_variable("x", lower_bound=0.0, upper_bound=1.0)

To do this, we create a training dataset with 6 equispaced points:

training_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=6
)
INFO - 16:21:49: *** Start Sampling execution ***
INFO - 16:21:49: Sampling
INFO - 16:21:49:    Disciplines: f
INFO - 16:21:49:    MDO formulation: MDF
INFO - 16:21:49: Running the algorithm PYDOE_FULLFACT:
INFO - 16:21:49:     17%|█▋        | 1/6 [00:00<00:00, 675.85 it/sec]
INFO - 16:21:49:     33%|███▎      | 2/6 [00:00<00:00, 1103.62 it/sec]
INFO - 16:21:49:     50%|█████     | 3/6 [00:00<00:00, 1427.44 it/sec]
INFO - 16:21:49:     67%|██████▋   | 4/6 [00:00<00:00, 1691.93 it/sec]
INFO - 16:21:49:     83%|████████▎ | 5/6 [00:00<00:00, 1906.50 it/sec]
INFO - 16:21:49:    100%|██████████| 6/6 [00:00<00:00, 2045.34 it/sec]
INFO - 16:21:49: *** End Sampling execution ***
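To make the content of the training dataset concrete, the following standalone sketch evaluates \(f\) at six equispaced points using only the standard library. It assumes that the six points include both bounds of the input space; the names `f`, `x_train` and `y_train` are illustrative and are not part of the example above.

```python
from math import sin


def f(x: float) -> float:
    """Evaluate f(x) = (6x - 2)^2 * sin(12x - 4)."""
    return (6 * x - 2) ** 2 * sin(12 * x - 4)


n_samples = 6
# Equispaced points on [0, 1], bounds included (assumed layout of the samples).
x_train = [i / (n_samples - 1) for i in range(n_samples)]
y_train = [f(x) for x in x_train]

for x, y in zip(x_train, y_train):
    print(f"x = {x:.1f}  ->  y = {y:+.4f}")
```

The output values change sign several times and grow sharply near the upper bound, so six samples give only a coarse picture of the function on the right of the interval.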

Basics#

Training#

Then, we train a random forest regression model from these samples:

model = create_regression_model("RandomForestRegressor", training_dataset)
model.learn()

Prediction#

Once it is built, we can predict the output value of \(f\) at a new input point:

input_value = {"x": array([0.65])}
output_value = model.predict(input_value)
output_value
{'y': array([-0.88837697])}

but cannot predict its Jacobian value:

try:
    model.predict_jacobian(input_value)
except NotImplementedError:
    print("The derivatives are not available for RandomForestRegressor.")
The derivatives are not available for RandomForestRegressor.
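A likely reason for the missing Jacobian is that a regression tree with constant leaf values is piecewise constant, and so is an average of such trees: the derivative is zero inside a leaf and undefined at a split. The sketch below illustrates this with a hand-built, hypothetical tree (thresholds halfway between six equispaced inputs, leaf values equal to rounded values of \(f\)); it is not the model trained above.

```python
from bisect import bisect_right

# Hypothetical fully grown tree on six equispaced samples: one sample per leaf,
# with split thresholds halfway between neighbouring inputs.
thresholds = [0.1, 0.3, 0.5, 0.7, 0.9]
leaf_values = [3.03, -0.64, 0.11, -0.15, -4.95, 15.83]  # rounded f(x_i)


def tree_predict(x: float) -> float:
    """Return the constant value of the leaf containing x."""
    return leaf_values[bisect_right(thresholds, x)]


def central_difference(func, x: float, step: float = 1e-6) -> float:
    """Approximate the derivative of func at x by a central finite difference."""
    return (func(x + step) - func(x - step)) / (2 * step)


# Inside a leaf, the prediction does not change with x.
print(central_difference(tree_predict, 0.65))
# Across a split, the finite difference blows up as the step shrinks.
print(central_difference(tree_predict, 0.7))
```

Inside a leaf the finite difference is exactly zero, while at a threshold it scales like the jump divided by the step, so no meaningful derivative can be reported.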

Plotting#

The plot below shows that the random forest model approximates the function reasonably well on the left of the input space, but poorly on the right:

test_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100
)
input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy()
reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel()
predicted_output_data = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function and random forest prediction ("Regression - Basics")]
INFO - 16:21:49: *** Start Sampling execution ***
INFO - 16:21:49: Sampling
INFO - 16:21:49:    Disciplines: f
INFO - 16:21:49:    MDO formulation: MDF
INFO - 16:21:49: Running the algorithm PYDOE_FULLFACT:
INFO - 16:21:49:    100%|██████████| 100/100 [00:00<00:00, 4294.27 it/sec]
INFO - 16:21:49: *** End Sampling execution ***
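The visual impression can be backed by a number, for instance the root mean squared error (RMSE) computed separately on each half of the input space. The sketch below is a standalone illustration: the helper `rmse`, the split at 0.5 and the nearest-sample predictor `stand_in_predict` are assumptions made for this sketch, and the stand-in is not the random forest trained above. The test points are assumed to be 100 equispaced values including the bounds.

```python
from math import sin, sqrt


def f(x: float) -> float:
    """Reference function f(x) = (6x - 2)^2 * sin(12x - 4)."""
    return (6 * x - 2) ** 2 * sin(12 * x - 4)


def rmse(reference, predicted) -> float:
    """Root mean squared error between two sequences of equal length."""
    squared_errors = [(r - p) ** 2 for r, p in zip(reference, predicted)]
    return sqrt(sum(squared_errors) / len(squared_errors))


x_train = [i / 5 for i in range(6)]
y_train = [f(x) for x in x_train]


def stand_in_predict(x: float) -> float:
    """Piecewise-constant stand-in: output of the nearest training sample."""
    nearest = min(range(len(x_train)), key=lambda i: abs(x_train[i] - x))
    return y_train[nearest]


x_test = [i / 99 for i in range(100)]
left = [x for x in x_test if x < 0.5]
right = [x for x in x_test if x >= 0.5]

rmse_left = rmse([f(x) for x in left], [stand_in_predict(x) for x in left])
rmse_right = rmse([f(x) for x in right], [stand_in_predict(x) for x in right])
print(f"RMSE on [0, 0.5): {rmse_left:.3f}")
print(f"RMSE on [0.5, 1]: {rmse_right:.3f}")
```

With the arrays of the example, the same helper could be applied to the entries of reference_output_data and predicted_output_data selected by the sign of input_data.ravel() - 0.5.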

Settings#

Number of estimators#

The main hyperparameter of random forest regression is the number of trees in the forest (default: 100). The following compares the default setting with a smaller (10) and a larger (1000) number of trees:

model = create_regression_model(
    "RandomForestRegressor", training_dataset, n_estimators=10
)
model.learn()
predicted_output_data_1 = model.predict(input_data).ravel()
model = create_regression_model(
    "RandomForestRegressor", training_dataset, n_estimators=1000
)
model.learn()
predicted_output_data_2 = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_1, label="Regression - 10 trees")
plt.plot(input_data.ravel(), predicted_output_data_2, label="Regression - 1000 trees")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function and random forest predictions with the default setting, 10 trees and 1000 trees]
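To see why the number of trees matters, recall that a random forest averages trees fitted on bootstrap resamples of the training data, so adding trees mainly reduces the randomness of the averaged prediction. The sketch below mimics this with the standard library only. It relies on a simplifying assumption: in one dimension, a fully grown tree with midpoint thresholds predicts the output of the nearest resampled point, so each tree is replaced by a nearest-sample lookup. The function `forest_predict` is hypothetical and is not the scikit-learn implementation.

```python
import random
from math import sin
from statistics import mean, pstdev


def f(x: float) -> float:
    """Reference function f(x) = (6x - 2)^2 * sin(12x - 4)."""
    return (6 * x - 2) ** 2 * sin(12 * x - 4)


x_train = [i / 5 for i in range(6)]
y_train = [f(x) for x in x_train]


def forest_predict(x: float, n_estimators: int, seed: int) -> float:
    """Average the predictions of n_estimators bootstrapped stand-in trees."""
    rng = random.Random(seed)
    tree_predictions = []
    for _ in range(n_estimators):
        # Bootstrap: draw as many samples as the training set, with replacement.
        indices = [rng.randrange(len(x_train)) for _ in x_train]
        # Stand-in for a fully grown 1D tree: output of the nearest resampled point.
        nearest = min(indices, key=lambda i: abs(x_train[i] - x))
        tree_predictions.append(y_train[nearest])
    return mean(tree_predictions)


# Spread of the prediction at x = 0.65 over 20 different random seeds.
spread = {}
for n_estimators in (10, 100, 1000):
    predictions = [forest_predict(0.65, n_estimators, seed) for seed in range(20)]
    spread[n_estimators] = pstdev(predictions)
    print(f"{n_estimators:5d} trees: spread over seeds = {spread[n_estimators]:.3f}")
```

The spread over seeds shrinks as trees are added, while the average prediction itself stays piecewise constant: more trees make the model more repeatable, not more accurate where the training data are too sparse.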

Others#

The RandomForestRegressor class of scikit-learn has many settings (see the scikit-learn documentation), and we have chosen to expose only n_estimators. However, any other argument of the scikit-learn RandomForestRegressor can be set through the dictionary parameters. For example, we can impose a minimum of two samples per leaf:

model = create_regression_model(
    "RandomForestRegressor", training_dataset, parameters={"min_samples_leaf": 2}
)
model.learn()
predicted_output_data_ = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - 2 samples")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function and random forest predictions with the default setting and with min_samples_leaf=2]
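The effect of a minimum number of samples per leaf can be illustrated on a single tree. The sketch below is a minimal one-dimensional regression tree written for this purpose, assuming splits that minimize the sum of squared errors and thresholds placed midway between neighbouring inputs; `build_tree`, `count_leaves` and `predict` are hypothetical helpers, not GEMSEO or scikit-learn functions, and no bootstrap is applied.

```python
from math import sin
from statistics import mean


def f(x: float) -> float:
    """Reference function f(x) = (6x - 2)^2 * sin(12x - 4)."""
    return (6 * x - 2) ** 2 * sin(12 * x - 4)


# Six equispaced training samples (x, y), sorted by x.
points = [(i / 5, f(i / 5)) for i in range(6)]


def sse(ys) -> float:
    """Sum of squared deviations from the mean."""
    centre = mean(ys)
    return sum((y - centre) ** 2 for y in ys)


def build_tree(points, min_samples_leaf):
    """Return a leaf value (float) or a (threshold, left, right) tuple."""
    best_cost, best_k = None, None
    # Only consider splits leaving at least min_samples_leaf samples on each side.
    for k in range(min_samples_leaf, len(points) - min_samples_leaf + 1):
        cost = sse([y for _, y in points[:k]]) + sse([y for _, y in points[k:]])
        if best_cost is None or cost < best_cost:
            best_cost, best_k = cost, k
    if best_k is None:  # no admissible split: the node becomes a leaf
        return mean(y for _, y in points)
    threshold = (points[best_k - 1][0] + points[best_k][0]) / 2
    return (
        threshold,
        build_tree(points[:best_k], min_samples_leaf),
        build_tree(points[best_k:], min_samples_leaf),
    )


def count_leaves(tree) -> int:
    if not isinstance(tree, tuple):
        return 1
    _, left, right = tree
    return count_leaves(left) + count_leaves(right)


def predict(tree, x: float) -> float:
    while isinstance(tree, tuple):
        threshold, left, right = tree
        tree = left if x <= threshold else right
    return tree


full_tree = build_tree(points, min_samples_leaf=1)
constrained_tree = build_tree(points, min_samples_leaf=2)
print(count_leaves(full_tree), count_leaves(constrained_tree))
print(predict(full_tree, 1.0), predict(constrained_tree, 1.0))
```

With one sample per leaf, the tree reproduces each training output; with at least two samples per leaf, every leaf averages several outputs, so the model is smoother and no longer passes through the training points.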

Total running time of the script: (0 minutes 0.931 seconds)
