Polynomial chaos expansion (PCE)#

A PCERegressor is a PCE model based on OpenTURNS.

from __future__ import annotations

from matplotlib import pyplot as plt
from numpy import array

from gemseo import configure_logger
from gemseo import create_discipline
from gemseo import create_parameter_space
from gemseo import sample_disciplines
from gemseo.mlearning import create_regression_model

configure_logger()
<RootLogger root (INFO)>

Problem#

In this example, we represent the function \(f(x)=(6x-2)^2\sin(12x-4)\) [FSK08] by the AnalyticDiscipline:

discipline = create_discipline(
    "AnalyticDiscipline",
    name="f",
    expressions={"y": "(6*x-2)**2*sin(12*x-4)"},
)

and seek to approximate it over the input space:

input_space = create_parameter_space()
input_space.add_random_variable("x", "OTUniformDistribution")
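For reference, the same function can be written directly with NumPy, which is handy for checking values by hand. This is only a sketch: the helper name `f` is chosen here for illustration and is not part of the GEMSEO API.

```python
import numpy as np


def f(x):
    """Evaluate f(x) = (6x - 2)^2 * sin(12x - 4) elementwise."""
    x = np.asarray(x, dtype=float)
    return (6.0 * x - 2.0) ** 2 * np.sin(12.0 * x - 4.0)


# f vanishes at x = 1/3 because the factor (6x - 2) is zero there.
print(f(1.0 / 3.0))
# Values at a few points of the unit interval.
print(f(np.array([0.0, 0.5, 1.0])))
```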

To approximate it, we create a training dataset with 10 equispaced points:

training_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=10
)
WARNING - 13:11:20: No coupling in MDA, switching chain_linearize to True.
   INFO - 13:11:20:
   INFO - 13:11:20: *** Start Sampling execution ***
   INFO - 13:11:20: Sampling
   INFO - 13:11:20:    Disciplines: f
   INFO - 13:11:20:    MDO formulation: MDF
   INFO - 13:11:20: Running the algorithm PYDOE_FULLFACT:
   INFO - 13:11:20:    100%|██████████| 10/10 [00:00<00:00, 2074.95 it/sec]
   INFO - 13:11:20: *** End Sampling execution (time: 0:00:00.006427) ***
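In one dimension, a full-factorial design reduces to a regular grid. The sketch below rebuilds such a grid with NumPy, assuming that the uniform variable is defined on \([0, 1]\) and that the design includes both bounds; neither assumption is confirmed by the output above, and the names `x_train` and `y_train` are illustrative.

```python
import numpy as np

n_samples = 10
# Assumption: the 1D full-factorial design spans [0, 1] including both bounds.
x_train = np.linspace(0.0, 1.0, n_samples)
y_train = (6.0 * x_train - 2.0) ** 2 * np.sin(12.0 * x_train - 4.0)

# Equispaced points: all consecutive gaps are equal to 1 / (n_samples - 1).
spacing = np.diff(x_train)
print(x_train)
print(spacing)
```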

Basics#

Training#

Then, we train a PCE regression model from these samples:

model = create_regression_model("PCERegressor", training_dataset)
model.learn()
WARNING - 13:11:20: Remove input data transformation because PCERegressor does not support transformers.
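To give an idea of what a PCE looks like for a uniform input, the sketch below fits a degree-2 expansion on Legendre polynomials by ordinary least squares with NumPy. It is an illustration under stated assumptions (inputs on \([0, 1]\), plain least squares, 10 equispaced points), not a description of what PCERegressor or OpenTURNS does internally.

```python
import numpy as np
from numpy.polynomial import Legendre


def f(x):
    """Reference function f(x) = (6x - 2)^2 * sin(12x - 4)."""
    return (6.0 * x - 2.0) ** 2 * np.sin(12.0 * x - 4.0)


# Assumed training grid: 10 equispaced points on [0, 1].
x_train = np.linspace(0.0, 1.0, 10)
y_train = f(x_train)

# Legendre polynomials are orthogonal with respect to the uniform distribution;
# domain=[0, 1] maps the inputs to the reference interval [-1, 1].
surrogate = Legendre.fit(x_train, y_train, deg=2, domain=[0.0, 1.0])

print(surrogate.coef)   # one coefficient per basis polynomial
print(surrogate(0.65))  # prediction at a new input point
```

The coefficients play the role of the PCE coefficients: the first one is the mean of the surrogate under the uniform distribution, since the higher-order Legendre polynomials have zero mean.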

Prediction#

Once it is built, we can predict the output value of \(f\) at a new input point:

input_value = {"x": array([0.65])}
output_value = model.predict(input_value)
output_value
{'y': array([-0.81106394])}

as well as its Jacobian value:

jacobian_value = model.predict_jacobian(input_value)
jacobian_value
{'y': {'x': array([[18.2279622]])}}
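A predicted Jacobian can be compared with the derivative of the reference function. The sketch below derives \(f'(x)\) with the product and chain rules and checks it against a central finite difference; the helper names `f` and `df` are introduced here for illustration only.

```python
import numpy as np


def f(x):
    """Reference function f(x) = (6x - 2)^2 * sin(12x - 4)."""
    return (6.0 * x - 2.0) ** 2 * np.sin(12.0 * x - 4.0)


def df(x):
    """Derivative of f obtained with the product and chain rules."""
    u = 6.0 * x - 2.0
    v = 12.0 * x - 4.0
    return 12.0 * u * np.sin(v) + 12.0 * u**2 * np.cos(v)


x0 = 0.65
step = 1e-6
# Central finite difference as an independent check of the analytic derivative.
finite_difference = (f(x0 + step) - f(x0 - step)) / (2.0 * step)
print(df(x0), finite_difference)
```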

Plotting#

As the plot shows, this quadratic model approximates \(f\) poorly:

test_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100
)
input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy()
reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel()
predicted_output_data = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function and "Regression - Basics" prediction over the input space]
WARNING - 13:11:20: No coupling in MDA, switching chain_linearize to True.
   INFO - 13:11:20:
   INFO - 13:11:20: *** Start Sampling execution ***
   INFO - 13:11:20: Sampling
   INFO - 13:11:20:    Disciplines: f
   INFO - 13:11:20:    MDO formulation: MDF
   INFO - 13:11:20: Running the algorithm PYDOE_FULLFACT:
   INFO - 13:11:20:    100%|██████████| 100/100 [00:00<00:00, 3450.94 it/sec]
   INFO - 13:11:20: *** End Sampling execution (time: 0:00:00.031507) ***
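A visual comparison can be complemented by a scalar error. The sketch below defines two hypothetical helpers, `rmse` and `r2`, with NumPy; the toy arrays merely stand in for `reference_output_data` and `predicted_output_data`.

```python
import numpy as np


def rmse(reference, prediction):
    """Root mean squared error between two arrays of the same shape."""
    reference = np.asarray(reference, dtype=float)
    prediction = np.asarray(prediction, dtype=float)
    return float(np.sqrt(np.mean((reference - prediction) ** 2)))


def r2(reference, prediction):
    """Coefficient of determination (1 means a perfect prediction)."""
    reference = np.asarray(reference, dtype=float)
    prediction = np.asarray(prediction, dtype=float)
    residual = np.sum((reference - prediction) ** 2)
    total = np.sum((reference - reference.mean()) ** 2)
    return float(1.0 - residual / total)


# Toy arrays standing in for the reference and predicted outputs.
reference = np.array([0.0, 1.0, 2.0, 3.0])
prediction = np.array([0.0, 1.0, 2.0, 5.0])
print(rmse(reference, prediction))  # 1.0
print(r2(reference, prediction))
```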

Settings#

The PCERegressor has many options defined in the PCERegressor_Settings Pydantic model.

Degree#

For example, we can raise the polynomial degree to 3 with the degree option:

model = create_regression_model("PCERegressor", training_dataset, degree=3)
model.learn()
WARNING - 13:11:20: Remove input data transformation because PCERegressor does not support transformers.

Plotting the new prediction next to the previous one suggests that the degree-3 model follows the reference more closely:

predicted_output_data_ = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - Degree(3)")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function with the "Regression - Basics" and "Regression - Degree(3)" predictions]
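The effect of the degree can also be explored outside GEMSEO. The sketch below fits Legendre expansions of increasing degree by plain least squares and reports training and test errors. It assumes inputs on \([0, 1]\) and equispaced grids of 10 and 100 points; it is not the algorithm used by PCERegressor, so the numbers need not match the model above.

```python
import numpy as np
from numpy.polynomial import Legendre


def f(x):
    """Reference function f(x) = (6x - 2)^2 * sin(12x - 4)."""
    return (6.0 * x - 2.0) ** 2 * np.sin(12.0 * x - 4.0)


# Assumed grids: 10 training points and 100 test points on [0, 1].
x_train = np.linspace(0.0, 1.0, 10)
x_test = np.linspace(0.0, 1.0, 100)
y_train, y_test = f(x_train), f(x_test)

train_errors = []
for degree in (2, 3, 5):
    surrogate = Legendre.fit(x_train, y_train, deg=degree, domain=[0.0, 1.0])
    train_rmse = np.sqrt(np.mean((surrogate(x_train) - y_train) ** 2))
    test_rmse = np.sqrt(np.mean((surrogate(x_test) - y_test) ** 2))
    train_errors.append(train_rmse)
    print(f"degree={degree}: train RMSE={train_rmse:.3f}, test RMSE={test_rmse:.3f}")
```

With least squares, the training error cannot increase when the degree grows, because each basis contains the previous one. The test error is the quantity to watch, since a high degree fitted on few points may overfit.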

Total running time of the script: (0 minutes 0.284 seconds)
