# Source code for gemseo.post.scatter_mat
# -*- coding: utf-8 -*-
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# Contributors:
# INITIAL AUTHORS - API and implementation and/or documentation
# :author: Francois Gallard
# :author: Damien Guenot
# OTHER AUTHORS - MACROSCOPIC CHANGES
"""A scatter plot matrix to display optimization history."""
from __future__ import division, unicode_literals
import logging
from typing import Sequence
from matplotlib import pyplot
from pandas.core.frame import DataFrame
from gemseo.post.opt_post_processor import OptPostProcessor
try:
from pandas.tools.plotting import scatter_matrix
except ImportError:
from pandas.plotting import scatter_matrix
# Logger of this module, named after the module itself.
LOGGER = logging.getLogger(name=__name__)
class ScatterPlotMatrix(OptPostProcessor):
    """Scatter plot matrix among design variables, output functions and constraints.

    The list of variable names has to be passed as arguments of the plot method.
    x- and y- figure sizes can be changed in option.
    """

    def _plot(
        self,
        variables_list,  # type: Sequence[str]
        figsize_x=10,  # type: int
        figsize_y=10,  # type: int
    ):  # type: (...) -> None
        """
        Args:
            variables_list: The functions names or design variables to plot.
                If the list is empty,
                plot all design variables.
                This sequence is never modified.
            figsize_x: The size of the figure in horizontal direction (inches).
            figsize_y: The size of the figure in vertical direction (inches).

        Raises:
            ValueError: When a name is neither a function
                of the optimization problem nor a design variable.
        """
        all_funcs = self.opt_problem.get_all_functions_names()
        all_dv_names = self.opt_problem.design_space.variables_names
        if not variables_list:
            # Plot all the design variables and no function.
            vals = self.database.get_x_history()
            # Readable labels for design variables,
            # i.e. toto_0, toto_1 if toto is a variable with 2 components.
            x_labels = self.__get_design_var_labels(all_dv_names)
        else:
            # The names are split into new lists instead of sorting and pruning
            # the caller's sequence in place: an in-place edit removed the
            # design variables from the caller's list (so a second call with
            # the same list lost them) and failed on immutable sequences.
            function_names, design_variables = self.__split_variables(
                variables_list, all_funcs, all_dv_names
            )
            if design_variables:
                # Sort the design variables to be consistent with GEMSEO.
                design_variables = sorted(
                    set(design_variables), key=all_dv_names.index
                )
                # Readable labels for design variables,
                # i.e. toto_0, toto_1 if toto is a variable with 2 components.
                dv_labels = self.__get_design_var_labels(design_variables)
                if function_names:
                    func_labels = self.__get_function_labels(function_names)
                else:
                    func_labels = []
                # vname contains function names + condensed variable names,
                # i.e. "toto" even if toto has 2 components or more,
                # while x_labels contains function labels + readable variable names.
                vname = function_names + design_variables
                x_labels = func_labels + dv_labels
            else:
                # Only functions are plotted.
                vname = function_names
                # The labels are kept in the order returned by the database,
                # as in the branch above: the data columns follow the (already
                # sorted) vname order, and a lexicographic sort of component
                # labels (e.g. "g_10" before "g_2") could misalign them.
                x_labels = self.__get_function_labels(function_names)
            dataset = self.opt_problem.export_to_dataset("OptimizationProblem")
            vals = dataset.get_data_by_names(vname, False)

        # Next line is a trick for a bug workaround in numpy/matplotlib
        # https://stackoverflow.com/questions/39180873/
        # pandas-dataframe-valueerror-num-must-be-1-num-0-not-1
        vals = [list(row) for row in vals]
        frame = DataFrame(vals, columns=x_labels)
        scatter_matrix(frame, alpha=1.0, figsize=(figsize_x, figsize_y), diagonal="kde")
        fig = pyplot.gcf()
        fig.tight_layout()
        self._add_figure(fig)

    def __split_variables(self, variables_list, all_funcs, all_dv_names):
        """Split the names to plot into function names and design variables.

        The input sequence is left untouched.

        Args:
            variables_list (Sequence(str)): The names to plot.
            all_funcs (list(str)): The names of the functions
                of the optimization problem.
            all_dv_names (list(str)): The names of the design variables.

        Returns:
            tuple(list(str), list(str)): The sorted function names to plot
            and the design variables to plot.

        Raises:
            ValueError: When a name is neither a function
                of the optimization problem nor a design variable.
        """
        function_names = []
        design_variables = []
        # Names are processed in sorted order so that, when several names are
        # invalid, the reported one is the first in sorted order.
        for name in sorted(variables_list):
            if name in all_dv_names:
                design_variables.append(name)
                continue
            if name in all_funcs:
                function_names.append(name)
                continue
            # When the objective is maximized and the problem names it "-f",
            # the user may refer to it as "f".
            negated_name = "-{}".format(name)
            if (
                negated_name == self.opt_problem.objective.name
                and not self.opt_problem.minimize_objective
            ):
                function_names.append(negated_name)
                continue
            msg = (
                "Cannot build scatter plot matrix, "
                "Function {} is neither among"
                " optimization problem functions : {}"
                " nor design variables : {}".format(name, all_funcs, all_dv_names)
            )
            raise ValueError(msg)
        # Renaming "f" into "-f" may break the sorted order: sort again.
        function_names.sort()
        return function_names, design_variables

    def __get_function_labels(self, function_names):
        """Return the labels of functions as provided by the database.

        Args:
            function_names (list(str)): The names of the functions.

        Returns:
            list(str): The labels of the functions, in the order returned by
            ``get_history_array`` (presumably the order of ``function_names``
            with one label per component -- TODO confirm).
        """
        _, labels, _ = self.database.get_history_array(
            functions=function_names,
            design_variables_names=None,
            add_dv=False,
        )
        return labels

    def __get_design_var_labels(self, des_vars):
        """Create labels for design variables.

        A variable of size 1 keeps its name;
        a variable ``x`` of size ``n != 1`` gets the labels ``x_0``, ..., ``x_{n-1}``.

        Args:
            des_vars (list(str)): The design variables to get its labels.

        Returns:
            list(str): The labels for the design variables.
        """
        sizes = self.opt_problem.design_space.variables_sizes
        labels = []
        for name in des_vars:
            size = sizes[name]
            if size == 1:
                labels.append(name)
            else:
                labels.extend("{}_{}".format(name, index) for index in range(size))
        return labels