# Source code for gemseo.mlearning.linear_model_fitting.elastic_net_settings
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""Settings for the scikit-learn elastic net algorithm."""
from __future__ import annotations
from typing import ClassVar
from typing import Literal
from numpy import ndarray # noqa: TC002
from numpy.random import RandomState # noqa: TC002
from pydantic import Field
from pydantic import NonNegativeFloat
from pydantic import PositiveFloat
from pydantic import PositiveInt
from gemseo.mlearning.linear_model_fitting.base_linear_model_fitter_settings import (
BaseLinearModelFitter_Settings,
)
from gemseo.settings.base_settings import BaseSettings
class _ElasticNetMixin(BaseSettings):
    """Mixin declaring the settings of the scikit-learn elastic net algorithm."""

    # NOTE: the declaration order below is kept as-is,
    # since pydantic exposes model fields in declaration order.

    # Whether the input data may be overwritten during fitting.
    copy_X: bool = Field(  # noqa: N815
        description=(
            "If ``True``, input data will be copied;\n"
            "else, it may be overwritten"
        ),
        default=True,
    )
    # Upper bound on the number of solver iterations.
    max_iter: PositiveInt = Field(
        description="The maximum number of iterations.",
        default=1000,
    )
    # Sign constraint on the coefficients.
    positive: bool = Field(
        description="When set to ``True``, forces the coefficients to be positive.",
        default=False,
    )
    # Either a flag or the Gram matrix itself.
    precompute: bool | ndarray = Field(
        description=(
            "Whether to use a precomputed Gram matrix\n"
            "to speed up calculations.\n"
            "The Gram matrix can also be passed as ``precompute`` value.\n"
            "For sparse input this option is always ``False`` to preserve sparsity."
        ),
        default=False,
    )
    # Seed (or generator) used by the "random" selection strategy.
    random_state: int | RandomState | None = Field(
        description=(
            "The seed of the pseudo random number generator\n"
            "that selects a random feature to update.\n"
            'Used when ``selection == "random"``.\n'
            "Pass an int for reproducible output across multiple function calls."
        ),
        default=None,
    )
    # Coordinate update strategy; restricted to two literal values.
    selection: Literal["cyclic", "random"] = Field(
        description=(
            'If set to "random",\n'
            "a random coefficient is updated every iteration\n"
            "rather than looping over features sequentially by default.\n"
            'This (setting to "random") often leads to significantly faster convergence\n'
            "especially when ``tol`` is higher than 1e-4."
        ),
        default="cyclic",
    )
    # Strictly positive stopping tolerance.
    tol: PositiveFloat = Field(
        description=(
            "The tolerance for the optimization:\n"
            "if the updates are smaller than ``tol``,\n"
            "the optimization code checks the dual gap for optimality\n"
            "and continues until it is smaller than ``tol``."
        ),
        default=1e-4,
    )
class ElasticNet_Settings(_ElasticNetMixin, BaseLinearModelFitter_Settings):  # noqa: N801
    """Settings for the scikit-learn elastic net algorithm.

    The solver-related fields (``copy_X``, ``max_iter``, ``positive``,
    ``precompute``, ``random_state``, ``selection`` and ``tol``)
    are inherited from ``_ElasticNetMixin``;
    this class adds the regularization parameters and ``warm_start``.
    """

    # Name of the target class these settings are for;
    # presumably used by the base settings to look it up — confirm in base class.
    _TARGET_CLASS_NAME: ClassVar[str] = "ElasticNet"

    alpha: NonNegativeFloat = Field(
        default=1.0,
        description=r"""The constant :math:`\alpha` that multiplies the L1 and L2 terms,
controlling regularization strength.""",
    )
    # NonNegativeFloat combined with ``le=1.0`` bounds l1_ratio to [0, 1].
    l1_ratio: NonNegativeFloat = Field(
        default=0.5,
        le=1.0,
        description=r"""The ElasticNet mixing parameter :math:`\rho`.
For ``l1_ratio = 0``, the penalty is an L2 penalty.
For ``l1_ratio = 1``, it is an L1 penalty.
For ``0 < l1_ratio < 1``, the penalty is a combination of L1 and L2.""",
    )
    warm_start: bool = Field(
        default=False,
        description="""When set to ``True``,
reuse the solution of the previous call to fit as initialization,
otherwise, just erase the previous solution.""",
    )