{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Mixture of experts with PCA on the Burgers dataset\n\nIn this example, we apply a mixture of experts (MoE) regression model to the\nBurgers dataset. To reduce the output dimension, each local regression model\napplies a principal component analysis (PCA) to its outputs before learning.\n"
      ]
    },
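    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The dimension reduction performed by the PCA can be sketched independently of\nGEMSEO. The following cell is a minimal illustration under stated assumptions:\nit uses NumPy only, replaces the Burgers outputs by hypothetical travelling sine\nwaves (50 samples of 500 points), computes the principal components from a\nsingular value decomposition of the centered outputs, keeps 20 of them and\nreconstructs the outputs. It is not the implementation of GEMSEO's `PCA`\ntransformer, which may differ in details such as scaling.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n# Hypothetical stand-in for the Burgers outputs: 50 travelling sine waves\n# sampled on 500 points (not the actual dataset).\nrng = np.random.default_rng(0)\nx = np.linspace(0.0, 1.0, 500)\nshifts = rng.uniform(0.0, 0.5, size=50)\noutputs_demo = np.sin(2 * np.pi * (x[None, :] - shifts[:, None]))\n\n# PCA via the singular value decomposition of the centered data matrix;\n# the rows of components are the principal directions.\nmean = outputs_demo.mean(axis=0)\n_, singular_values, components = np.linalg.svd(\n    outputs_demo - mean, full_matrices=False\n)\n\n# Keep the first 20 principal directions, as in the regressor of this example.\nn_components = 20\nbasis = components[:n_components]  # shape (20, 500)\nreduced = (outputs_demo - mean) @ basis.T  # shape (50, 20)\nreconstructed = reduced @ basis + mean  # shape (50, 500)\n\n# Share of the variance captured by the retained components.\nexplained_ratio = singular_values**2 / np.sum(singular_values**2)\nprint(reduced.shape)\nprint(explained_ratio[:n_components].sum())\nprint(np.abs(reconstructed - outputs_demo).max())"
      ]
    },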
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Imports\nImport from the standard library, third-party libraries and GEMSEO.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from __future__ import annotations\n\nimport matplotlib.pyplot as plt\nfrom gemseo.api import configure_logger\nfrom gemseo.api import load_dataset\nfrom gemseo.mlearning.api import create_regression_model\nfrom matplotlib.lines import Line2D\nfrom numpy import nonzero\n\nconfigure_logger()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load the Burgers dataset\nWe load 50 samples of the Burgers dataset and extract its input and output data.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "n_samples = 50\ndataset = load_dataset(\"BurgersDataset\", n_samples=n_samples)\ninputs = dataset.get_data_by_group(dataset.INPUT_GROUP)\noutputs = dataset.get_data_by_group(dataset.OUTPUT_GROUP)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Mixture of experts (MoE)\nIn this section, we create a mixture of experts regression model through the\nmachine learning API. It combines a clustering model, a classification model\nand local regression models.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
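        "### Mixture of experts model\nWe construct the MoE model from the dataset and configure its three components:\n\n- a K-means clusterer with 2 clusters, with a Jameson sensor transformer\n  on the outputs;\n- a k-nearest neighbors classifier with 3 neighbors, predicting the cluster\n  of an input;\n- a Gaussian process regressor for each cluster, whose outputs are reduced\n  to 20 components by a PCA.\n\nThen we fit the model to the dataset with the `learn()` method.\n\n"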
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = create_regression_model(\"MOERegressor\", dataset)\nmodel.set_clusterer(\"KMeans\", n_clusters=2, transformer={\"outputs\": \"JamesonSensor\"})\nmodel.set_classifier(\"KNNClassifier\", n_neighbors=3)\nmodel.set_regressor(\n    \"GaussianProcessRegressor\", transformer={\"outputs\": (\"PCA\", {\"n_components\": 20})}\n)\n\nmodel.learn()"
      ]
    },
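    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the roles of the three components explicit, the following cell\nsketches the same clustering, classification and regression pipeline directly\nwith scikit-learn. It is an illustration under stated assumptions, not GEMSEO's\n`MOERegressor` implementation: it uses hypothetical one-dimensional data with\ntwo regimes instead of the Burgers dataset, clusters the samples on their raw\noutputs (without Jameson sensor or PCA), and routes each new input to a single\nexpert (hard prediction). The helper `predict_moe` is hypothetical.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nfrom sklearn.cluster import KMeans\nfrom sklearn.gaussian_process import GaussianProcessRegressor\nfrom sklearn.neighbors import KNeighborsClassifier\n\n# Hypothetical 1D data with two regimes, standing in for the Burgers dataset.\nrng = np.random.default_rng(0)\nx_demo = rng.uniform(-1.0, 1.0, size=(60, 1))\ny_demo = np.where(x_demo < 0.0, -1.0 - x_demo, 2.0 + x_demo**2)\n\n# 1. Clustering: split the training samples into 2 clusters from their outputs.\nclusterer = KMeans(n_clusters=2, n_init=10, random_state=0).fit(y_demo)\nlabels = clusterer.labels_\n\n# 2. Classification: learn to predict the cluster of a sample from its inputs.\nclassifier = KNeighborsClassifier(n_neighbors=3).fit(x_demo, labels)\n\n# 3. Regression: fit one local expert per cluster on that cluster's samples.\nexperts = [\n    GaussianProcessRegressor(normalize_y=True).fit(\n        x_demo[labels == k], y_demo[labels == k]\n    )\n    for k in range(2)\n]\n\n\ndef predict_moe(x_new):\n    \"\"\"Route each input to the expert of its predicted cluster (hard prediction).\"\"\"\n    clusters = classifier.predict(x_new)\n    y_new = np.empty((len(x_new), 1))\n    for k, expert in enumerate(experts):\n        mask = clusters == k\n        if mask.any():\n            y_new[mask] = expert.predict(x_new[mask]).reshape(-1, 1)\n    return y_new\n\n\nprint(predict_moe(np.array([[-0.5], [0.5]])))"
      ]
    },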
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Make predictions\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
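        "# Predictions of the whole mixture of experts.\npredictions = model.predict(inputs)\n# Predictions of the local models (experts) associated with clusters 0 and 1.\nlocal_pred_0 = model.predict_local_model(inputs, 0)\nlocal_pred_1 = model.predict_local_model(inputs, 1)"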
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Plot clusters\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Plot the training outputs: red for cluster 0, blue for cluster 1.\nfor i in nonzero(model.clusterer.labels == 0)[0]:\n    plt.plot(outputs[i], color=\"r\")\nfor i in nonzero(model.clusterer.labels == 1)[0]:\n    plt.plot(outputs[i], color=\"b\")\nplt.legend(\n    [Line2D([0], [0], color=\"r\"), Line2D([0], [0], color=\"b\")],\n    [\"Cluster 0\", \"Cluster 1\"],\n)\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Plot predictions\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
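        "def lines(i):\n    \"\"\"Return a distinct matplotlib dash pattern for the i-th curve.\"\"\"\n    return (0, (i + 3, 1, 1, 1))\n\n\n# Plot the MoE predictions, colored by the cluster label of each training sample.\nfor i, pred in enumerate(predictions):\n    color = \"b\"\n    if model.labels[i] == 0:\n        color = \"r\"\n    plt.plot(pred, color=color, linestyle=lines(i))\nplt.show()"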
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Plot local models\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "plt.subplot(121)\nfor i, pred in enumerate(local_pred_0):\n    plt.plot(pred, color=\"r\", linestyle=lines(i))\nplt.subplot(122)\nfor i, pred in enumerate(local_pred_1):\n    plt.plot(pred, color=\"b\", linestyle=lines(i))\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Plot selected predictions and exact curves\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Compare the exact outputs (solid red) with the MoE predictions (dotted blue)\n# for five samples spread over the dataset.\nfor i in [\n    0,\n    dataset.n_samples // 4,\n    dataset.n_samples * 2 // 4,\n    dataset.n_samples * 3 // 4,\n    -1,\n]:\n    plt.plot(outputs[i], color=\"r\")\n    plt.plot(predictions[i], color=\"b\", linestyle=\":\")\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Plot the PCA components of the local models\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "local_models = model.regress_models\nplt.subplot(121)\nplt.plot(local_models[0].transformer[\"outputs\"].components)\nplt.title(\"1st local model\")\nplt.subplot(122)\nplt.plot(local_models[1].transformer[\"outputs\"].components)\nplt.title(\"2nd local model\")\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}