Note: Go to the end to download the full example code.
Random forest#
A RandomForestRegressor is a random forest regression model based on scikit-learn.
from __future__ import annotations
from matplotlib import pyplot as plt
from numpy import array
from gemseo import configure_logger
from gemseo import create_design_space
from gemseo import create_discipline
from gemseo import sample_disciplines
from gemseo.mlearning import create_regression_model
configure_logger()
<RootLogger root (INFO)>
Problem#
In this example, we represent the function \(f(x)=(6x-2)^2\sin(12x-4)\) [FSK08] by the AnalyticDiscipline
discipline = create_discipline(
"AnalyticDiscipline",
name="f",
expressions={"y": "(6*x-2)**2*sin(12*x-4)"},
)
and seek to approximate it over the input space
input_space = create_design_space()
input_space.add_variable("x", lower_bound=0.0, upper_bound=1.0)
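Before sampling, we can evaluate the discipline once as a sanity check. This is a minimal sketch assuming the standard execute() call of GEMSEO disciplines; since \(f(0)=4\sin(-4)\approx 3.03\), the printed value should be close to that.
# Sanity check of the analytic expression at x = 0 (assumes execute()).
print(discipline.execute({"x": array([0.0])})["y"])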
To approximate it, we create a training dataset with 6 equispaced points:
training_dataset = sample_disciplines(
[discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=6
)
WARNING - 08:24:04: No coupling in MDA, switching chain_linearize to True.
INFO - 08:24:04: *** Start Sampling execution ***
INFO - 08:24:04: Sampling
INFO - 08:24:04: Disciplines: f
INFO - 08:24:04: MDO formulation: MDF
INFO - 08:24:04: Running the algorithm PYDOE_FULLFACT:
INFO - 08:24:04: 17%|█▋ | 1/6 [00:00<00:00, 575.98 it/sec]
INFO - 08:24:04: 33%|███▎ | 2/6 [00:00<00:00, 923.55 it/sec]
INFO - 08:24:04: 50%|█████ | 3/6 [00:00<00:00, 1180.39 it/sec]
INFO - 08:24:04: 67%|██████▋ | 4/6 [00:00<00:00, 1389.53 it/sec]
INFO - 08:24:04: 83%|████████▎ | 5/6 [00:00<00:00, 1552.64 it/sec]
INFO - 08:24:04: 100%|██████████| 6/6 [00:00<00:00, 1688.30 it/sec]
INFO - 08:24:04: *** End Sampling execution (time: 0:00:00.004644) ***
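We can glance at the resulting samples with the same get_view accessor used below for plotting:
# The 6 training samples (inputs and outputs).
print(training_dataset.get_view(variable_names=["x", "y"]).to_numpy())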
Basics#
Training#
Then, we train a random forest regression model from these samples:
model = create_regression_model("RandomForestRegressor", training_dataset)
model.learn()
Prediction#
Once it is built, we can predict the output value of \(f\) at a new input point:
input_value = {"x": array([0.65])}
output_value = model.predict(input_value)
output_value
{'y': array([-0.88837697])}
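For reference, evaluating the original discipline at the same point (same execute() call as in the sanity check above) gives \(f(0.65)=1.9^2\sin(3.8)\approx -2.21\), so the surrogate is noticeably off here; the plot below shows why.
# Exact value at x = 0.65, to be compared with the prediction above.
print(discipline.execute(input_value)["y"])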
However, the model cannot predict the Jacobian of \(f\):
try:
model.predict_jacobian(input_value)
except NotImplementedError:
print("The derivatives are not available for RandomForestRegressor.")
The derivatives are not available for RandomForestRegressor.
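This limitation is intrinsic: a random forest prediction is piecewise constant, so its derivative is zero almost everywhere. If a rough slope is nevertheless needed, a finite-difference sketch built on predict() is always possible (the step size is an arbitrary choice):
# Central finite difference of the surrogate at x = 0.65;
# expect zero or spiky values, as the prediction is piecewise constant.
step = 1e-3
f_minus = model.predict({"x": array([0.65 - step])})["y"]
f_plus = model.predict({"x": array([0.65 + step])})["y"]
print((f_plus - f_minus) / (2 * step))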
Plotting#
You can see that the random forest model approximates the reference well on the left of the input space but poorly on the right:
test_dataset = sample_disciplines(
[discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100
)
input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy()
reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel()
predicted_output_data = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function vs. random forest prediction]
WARNING - 08:24:04: No coupling in MDA, switching chain_linearize to True.
INFO - 08:24:04: *** Start Sampling execution ***
INFO - 08:24:04: Sampling
INFO - 08:24:04: Disciplines: f
INFO - 08:24:04: MDO formulation: MDF
INFO - 08:24:04: Running the algorithm PYDOE_FULLFACT:
INFO - 08:24:04: 1%| | 1/100 [00:00<00:00, 2874.78 it/sec]
...
INFO - 08:24:04: 100%|██████████| 100/100 [00:00<00:00, 3219.13 it/sec]
INFO - 08:24:04: *** End Sampling execution (time: 0:00:00.032376) ***
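Beyond visual inspection, a single-number summary such as the root-mean-square error over the test points complements the plot; this sketch uses only NumPy:
from numpy import sqrt

# RMSE between the surrogate and the reference on the 100 test points.
rmse = sqrt(((predicted_output_data - reference_output_data) ** 2).mean())
print(f"RMSE: {rmse:.3f}")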
Settings#
Number of estimators#
The main hyperparameter of random forest regression is the number of trees in the forest (default: 100). Here is a comparison when decreasing and increasing this number:
model = create_regression_model(
"RandomForestRegressor", training_dataset, n_estimators=10
)
model.learn()
predicted_output_data_1 = model.predict(input_data).ravel()
model = create_regression_model(
"RandomForestRegressor", training_dataset, n_estimators=1000
)
model.learn()
predicted_output_data_2 = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_1, label="Regression - 10 trees")
plt.plot(input_data.ravel(), predicted_output_data_2, label="Regression - 1000 trees")
plt.grid()
plt.legend()
plt.show()
[Figure: reference vs. random forest predictions with 10, 100, and 1000 trees]
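To compare these settings quantitatively rather than visually, the same RMSE computation can be wrapped in a loop over the number of trees (a sketch reusing the calls above):
from numpy import sqrt

# Train one model per forest size and report its test RMSE.
for n_estimators in (10, 100, 1000):
    model = create_regression_model(
        "RandomForestRegressor", training_dataset, n_estimators=n_estimators
    )
    model.learn()
    predicted = model.predict(input_data).ravel()
    rmse = sqrt(((predicted - reference_output_data) ** 2).mean())
    print(f"{n_estimators} trees: RMSE = {rmse:.3f}")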
Others#
The RandomForestRegressor class of scikit-learn has many settings (read more), and we have chosen to expose only n_estimators. However, any argument of RandomForestRegressor can be set using the dictionary parameters. For example, we can impose a minimum of two samples per leaf:
model = create_regression_model(
"RandomForestRegressor", training_dataset, parameters={"min_samples_leaf": 2}
)
model.learn()
predicted_output_data_ = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - 2 samples")
plt.grid()
plt.legend()
plt.show()
[Figure: reference vs. random forest prediction with at least two samples per leaf]
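For comparison, here is the same configuration written directly against scikit-learn, which GEMSEO wraps internally; this standalone sketch extracts the 6 training samples with the get_view accessor used above:
from sklearn.ensemble import RandomForestRegressor as SklearnRF

# Rebuild the training data and fit the scikit-learn estimator directly.
x_train = training_dataset.get_view(variable_names="x").to_numpy()
y_train = training_dataset.get_view(variable_names="y").to_numpy().ravel()
sk_model = SklearnRF(min_samples_leaf=2)
sk_model.fit(x_train, y_train)
print(sk_model.predict(array([[0.65]])))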
Total running time of the script: (0 minutes 1.448 seconds)