# Source code for gemseo.core.factory

# -*- coding: utf-8 -*-
# Copyright 2021 IRT Saint Exupéry,
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# Contributors:
#    INITIAL AUTHORS - initial API and implementation and/or
#                      initial documentation
#        :author:  Francois Gallard
"""Factory base class."""
from __future__ import absolute_import, division, print_function, unicode_literals

import importlib
import os
import pkgutil
import sys

from future import standard_library
from future.utils import with_metaclass

from gemseo.core.json_grammar import JSONGrammar
from gemseo.third_party.prettytable import PrettyTable
from gemseo.utils.singleton import SingleInstancePerAttributeEq
from gemseo.utils.source_parsing import SourceParsing


from gemseo import LOGGER

class Factory(with_metaclass(SingleInstancePerAttributeEq, object)):
    """Factory to create extensions that are known to |g|.

    Depending on the subclass, the extensions can be MDODiscipline,
    MDOFormulation...

    The classes are searched in, in this order:

    - the internal modules (``internal_modules_paths``, such as
      gemseo.problems...),
    - the installed top-level packages whose name starts with
      ``PLUGIN_PREFIX``,
    - the directories listed in the environment variable "GEMSEO_PATH"
      (and in the deprecated "GEMS_PATH").
    """

    # Name of the environment variable to search for classes
    GEMSEO_PATH = "GEMSEO_PATH"
    # Deprecated name of the environment variable, still scanned
    GEMS_PATH = "GEMS_PATH"

    # Allowed prefix for naming a plugin in sys.path
    PLUGIN_PREFIX = "gemseo_"

    def __init__(self, base_class=None, internal_modules_paths=None):
        """Initializes the factory.

        Imports the internal modules, the plugins and the modules found in
        the directories of "GEMSEO_PATH" and "GEMS_PATH", then collects the
        subclasses of ``base_class``.

        :param base_class: class to search in the modules
            (MDOFormulation, MDODiscipline...) depending on the subclass
        :param internal_modules_paths: import paths (such as gemseo.problems)
            of the packages to scan for subclasses of ``base_class``
        :raises TypeError: if ``base_class`` is not a class
        """
        if not isinstance(base_class, type):
            raise TypeError("Class to search must be a class!")
        self.base_class = base_class
        self.internal_modules_paths = internal_modules_paths or []
        # Maps a module or package name to the reason of its import failure.
        self.failed_imports = {}
        self.__names_to_classes = {}
        self.update()

    def _update_path_from_env_variable(self, env_variable):
        """Import the modules found in the directories of an env variable.

        The variable holds directories separated by ``os.pathsep``
        (":" on POSIX, ";" on Windows); empty entries are ignored.

        :param env_variable: name of the environment variable
        """
        env_value = os.environ.get(env_variable)
        if not env_value:
            return
        # Split on os.pathsep rather than a hard-coded ":" so that Windows
        # drive letters ("C:\\...") are not broken apart; drop empty entries
        # since "" on sys.path means the current working directory.
        paths = [path for path in env_value.split(os.pathsep) if path]
        if not paths:
            return

        # Temporarily make the directories visible to the import machinery.
        for path in paths:
            sys.path.insert(0, path)
        try:
            for _, mod_name, _ in pkgutil.iter_modules(path=paths):
                self.__import_modules_from(mod_name)
        finally:
            # Restore sys.path even if an import raised, removing only the
            # entries added above (imported modules may have added others).
            for path in paths:
                if path in sys.path:
                    sys.path.remove(path)

    def update(self):
        """Update the classes that can be created by the factory.

        In order, scan the internal modules, then the plugins, then the
        directories of GEMSEO_PATH and of the deprecated GEMS_PATH.
        """
        # Scan internal packages
        for mod_name in self.internal_modules_paths:
            self.__import_modules_from(mod_name)

        # Scan plugins packages: top-level packages named with PLUGIN_PREFIX
        for _, mod_name, _ in pkgutil.iter_modules():
            if mod_name.startswith(self.PLUGIN_PREFIX):
                self.__import_modules_from(mod_name)

        if os.environ.get(self.GEMS_PATH) is not None:
            msg = (
                "GEMS is now named GEMSEO. "
                "The GEMS_PATH environment variable is now deprecated "
                "and it is strongly recommended to use "
                "the GEMSEO_PATH environment variable instead "
                "to register your GEMSEO plugins."
            )
            LOGGER.warning(msg)

        # Scan environment variable paths
        for env_variable in (self.GEMSEO_PATH, self.GEMS_PATH):
            self._update_path_from_env_variable(env_variable)

        self.__names_to_classes = self.__get_sub_classes(self.base_class)

    def __log_import_failure(self, pkg_name):
        """Log and record the failure to import a package.

        Used as the ``onerror`` callback of ``pkgutil.walk_packages``.

        :param pkg_name: name of the package that failed to import
        """
        # walk_packages invokes onerror while handling the import error,
        # so the exception is usually available; fall back to "" otherwise.
        error = sys.exc_info()[1]
        LOGGER.debug("Failed to import package %s", pkg_name)
        self.failed_imports[pkg_name] = error if error is not None else ""

    def __import_modules_from(self, pkg_name):
        """Import a package and, recursively, all its modules.

        Failures to import sub-packages or modules are logged and recorded
        in ``failed_imports``; a failure to import ``pkg_name`` itself is
        propagated.

        :param pkg_name: name of the package (or module) to import
        """
        pkg = importlib.import_module(pkg_name)
        if not hasattr(pkg, "__path__"):
            # not a package so no more module to import
            return

        for _, mod_name, _ in pkgutil.walk_packages(
            pkg.__path__, pkg.__name__ + ".", self.__log_import_failure
        ):
            try:
                importlib.import_module(mod_name)
            except Exception as err:  # pylint: disable=broad-except
                # A broken module must not prevent the others from loading.
                LOGGER.debug("Failed to import module: %s, %s", mod_name, err)
                self.failed_imports[mod_name] = err

    def __get_sub_classes(self, cls):
        """Return all the sub classes of cls, recursively.

        The class names are unique: when more than one class have the same
        name, the last one encountered is kept.

        :param cls: the class whose descendants are collected
        :returns: the dict of class name: class
        """
        all_sub_classes = {}
        for sub_class in cls.__subclasses__():
            all_sub_classes[sub_class.__name__] = sub_class
            all_sub_classes.update(self.__get_sub_classes(sub_class))
        return all_sub_classes

    @property
    def classes(self):
        """Return the available classes.

        :returns: the sorted list of classes names
        """
        return sorted(self.__names_to_classes.keys())

    # TODO: rename to has_class
    def is_available(self, name):
        """Return whether a class is available.

        :param name: name of the class
        :returns: True if the class is available
        """
        return name in self.__names_to_classes

    def get_class(self, name):
        """Return the class from its name.

        :param name: name of the class
        :returns: the class
        :raises ImportError: if the class is not available
        """
        try:
            return self.__names_to_classes[name]
        except KeyError:
            msg = "Class {} is not available!\nAvailable ones are: {}".format(
                name, ", ".join(sorted(self.__names_to_classes.keys()))
            )
            raise ImportError(msg)

    def create(self, class_name, **options):
        """Return an instance with given class name.

        :param class_name: name of the class
        :param options: options to be passed to the constructor
        :returns: the instance
        :raises ImportError: if the class is not available
        :raises TypeError: if the constructor rejects the options
        """
        cls = self.get_class(class_name)
        try:
            return cls(**options)
        except TypeError:
            # TODO: raise an error with message and let the callers handle
            # logging
            LOGGER.error(
                "Failed to create class %s with arguments %s", class_name, options
            )
            raise

    def get_options_doc(self, name):
        """Return the options documentation for the given class name.

        :param name: name of the class
        :returns: the dict of option name: option documentation
        """
        cls = self.get_class(name)
        return SourceParsing.get_options_doc(cls.__init__)

    def get_default_options_values(self, name):
        """Return the options default values for the given class name.

        Only addresses kwargs.

        :param name: name of the class
        :returns: the dict option name: option default value
        """
        cls = self.get_class(name)
        return SourceParsing.get_default_options_values(cls)

    def get_options_grammar(self, name, write_schema=False, schema_file=None):
        """Return the options grammar for a class.

        Attempts to generate a JSONGrammar from the arguments of the
        __init__ method of the class. Options whose default value is None
        are removed from the required ones.

        :param name: name of the class
        :param write_schema: if True, writes the schema files
            (Default value = False)
        :param schema_file: the output json file path. If None: input.json
            or output.json depending on grammar type. (Default value = None)
        :returns: the json grammar for options
        """
        args_dict = self.get_default_options_values(name)
        opts_doc = self.get_options_doc(name)
        # Only keep the documentation of the actual options.
        opts_doc = {k: v for k, v in opts_doc.items() if k in args_dict}
        gramm = JSONGrammar(name)
        gramm.initialize_from_base_dict(
            args_dict,
            schema_file=schema_file,
            write_schema=write_schema,
            description_dict=opts_doc,
        )

        # Remove None args from required; .get guards against a schema
        # without a "required" entry, which would otherwise raise KeyError.
        sch_dict = gramm.schema.to_dict()
        required = sch_dict.get("required", [])
        has_changed = False
        for opt, val in args_dict.items():
            if val is None and opt in required:
                required.remove(opt)
                has_changed = True
        if has_changed:
            gramm = JSONGrammar(name, schema=sch_dict)
        return gramm

    def get_sub_options_grammar(self, class_name, **options):
        """Return the JSONGrammar of the sub options of a class.

        :param class_name: name of the class
        :param options: options to be passed to the class
            required to deduce the sub options
        """
        cls = self.get_class(class_name)
        return cls.get_sub_options_grammar(**options)

    def get_default_sub_options_values(self, class_name, **options):
        """Return the default values of the sub options of a class.

        :param class_name: name of the class
        :param options: options to be passed to the class
            required to deduce the sub options
        """
        cls = self.get_class(class_name)
        return cls.get_default_sub_options_values(**options)

    def __str__(self):
        """Return the representation of a factory.

        Gives the configuration with the successfully loaded modules
        and failed imports with the reason.
        """
        table = PrettyTable(
            ["Module", "Is available ?", "Purpose or error message"],
            title=self.base_class.__name__,
            min_table_width=120,
            max_table_width=120,
        )

        row_dict = {}
        for cls in self.__names_to_classes.values():
            # The purpose is the first non-blank line of the docstring;
            # __doc__ is None for undocumented classes.
            doc_lines = (cls.__doc__ or "").split("\n")
            purpose = next(
                (line.strip() for line in doc_lines if line.strip()), ""
            )
            key = cls.__name__
            row_dict[key] = [key, "Yes", purpose]

        for key, err in self.failed_imports.items():
            row_dict[key] = [key, "No", str(err)]

        # Take them all and then sort them for pretty printing
        for key in sorted(row_dict.keys()):
            table.add_row(row_dict[key])

        return table.get_string()