# Source code for gemseo.algos.progress_bar
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""Progress bar."""
from __future__ import annotations
import io
import logging
import string
from typing import Any
from typing import Callable
import tqdm
from tqdm.utils import _unicode
from tqdm.utils import disp_len
# Module-level logger to which TqdmToLogger.write forwards the progress messages.
LOGGER = logging.getLogger(__name__)
class TqdmToLogger(io.StringIO):
    """File-like object forwarding the tqdm output to the gemseo logger."""

    def write(self, buf: str) -> None:
        """Log the text at INFO level once stripped of surrounding whitespace.

        Nothing is logged when the stripped text is empty.

        Args:
            buf: The text to be logged.
        """
        message = buf.strip(string.whitespace)
        if not message:
            # Blank output (e.g. a lone carriage return) is not worth a record.
            return
        LOGGER.info(message)
class ProgressBar(tqdm.tqdm):
    """Extend tqdm progress bar with better time units.

    Use minute, hour or day for slower processes.
    """

    @staticmethod
    def __convert_rate(total: float, elapsed: float) -> tuple[float, str]:
        """Express an iteration rate in the most readable time unit.

        The rate is given per second by default, and per minute, hour or day
        when fewer than one iteration is done per second, minute or hour
        respectively.

        Args:
            total: The number of iterations.
            elapsed: The elapsed time, in seconds.

        Returns:
            The rate and its unit formatted as ``" it/<unit>"``.

        Raises:
            ZeroDivisionError: If ``elapsed`` is zero.
        """
        rps = float(total) / elapsed
        # Default to the per-second rate so that the result is always defined,
        # even when none of the comparisons below holds (e.g. a NaN rate).
        rate, unit = rps, "sec"

        rpm = rps * 60.0
        if rpm < 60.0:
            rate, unit = rpm, "min"

        rph = rpm * 60.0
        if rph < 60.0:
            rate, unit = rph, "hour"

        rpd = rph * 24.0
        if rpd < 24.0:
            rate, unit = rpd, "day"

        return rate, f" it/{unit}"

    def status_printer(
        self, file: io.TextIOWrapper | io.StringIO
    ) -> Callable[[str], None]:
        """Overload the status_printer method to avoid the use of closures.

        A bound method is returned instead of a locally defined function.
        The messages are written to ``self.fp``; ``file`` itself is not used.

        Args:
            file: Specifies where to output the progress messages.

        Returns:
            The function to print the status in the progress bar.
        """
        # Display length of the last printed status, kept in a one-item list.
        self._last_len = [0]
        return self._print_status

    def _print_status(self, s: str) -> None:
        """Print a status over the previous one.

        Args:
            s: The status to print.
        """
        len_s = disp_len(s)
        # Pad with spaces so that a shorter status fully overwrites the last one.
        padding = " " * max(self._last_len[0] - len_s, 0)
        self.fp.write(_unicode(f"\r{s}{padding}"))
        self.fp.flush()
        self._last_len[0] = len_s

    def __getstate__(self) -> dict[str, Any]:
        """Return the instance attributes without the output stream.

        Returns:
            A shallow copy of the instance attributes, ``fp`` excluded.
        """
        state = self.__dict__.copy()
        # A file-like stream cannot be pickled.
        # The attribute may be missing, e.g. tqdm does not seem to set it
        # on a disabled progress bar, hence pop() rather than del.
        state.pop("fp", None)
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore the instance attributes and create a new output stream.

        Args:
            state: The instance attributes, as returned by ``__getstate__``.
        """
        self.__dict__.update(state)
        # Set back the file-like stream to its state as done in tqdm.__init__.
        self.fp = tqdm.utils.DisableOnWriteError(TqdmToLogger(), tqdm_instance=self)