From edb9803924ab60ede3077dc72331c41d13e0c322 Mon Sep 17 00:00:00 2001 From: Rick van Hattem Date: Fri, 30 Aug 2024 00:14:16 +0200 Subject: [PATCH] many more type hinting improvements, not fully strict yet but in progress --- progressbar/algorithms.py | 3 +- progressbar/bar.py | 1 + progressbar/multi.py | 80 ++++++++++++++++----------- progressbar/shortcuts.py | 32 +++++++---- progressbar/terminal/base.py | 103 ++++++++++++++++++++++------------- progressbar/widgets.py | 34 +++++++----- pyproject.toml | 39 ++++++------- tests/test_color.py | 12 +++- 8 files changed, 188 insertions(+), 116 deletions(-) diff --git a/progressbar/algorithms.py b/progressbar/algorithms.py index cf0faf2..c0cb7a1 100644 --- a/progressbar/algorithms.py +++ b/progressbar/algorithms.py @@ -1,12 +1,13 @@ from __future__ import annotations import abc +import typing from datetime import timedelta class SmoothingAlgorithm(abc.ABC): @abc.abstractmethod - def __init__(self, **kwargs): + def __init__(self, **kwargs: typing.Any): raise NotImplementedError @abc.abstractmethod diff --git a/progressbar/bar.py b/progressbar/bar.py index c267fa2..3a5666e 100644 --- a/progressbar/bar.py +++ b/progressbar/bar.py @@ -845,6 +845,7 @@ def __iter__(self): return self def __next__(self): + value: typing.Any try: if self._iterable is None: # pragma: no cover value = self.value diff --git a/progressbar/multi.py b/progressbar/multi.py index 8900b89..948b20c 100644 --- a/progressbar/multi.py +++ b/progressbar/multi.py @@ -8,6 +8,7 @@ import threading import time import timeit +import types import typing from datetime import timedelta @@ -19,6 +20,10 @@ SortKeyFunc = typing.Callable[[bar.ProgressBar], typing.Any] +class _Update(typing.Protocol): + def __call__(self, force: bool = True, write: bool = True) -> str: ... + + class SortKey(str, enum.Enum): """ Sort keys for the MultiBar. @@ -80,7 +85,7 @@ def __init__( fd: typing.TextIO = sys.stderr, prepend_label: bool = True, append_label: bool = False, - label_format='{label:20.20} ', + label_format: str = '{label:20.20} ', initial_format: str | None = '{label:20.20} Not yet started', finished_format: str | None = None, update_interval: float = 1 / 60.0, # 60fps @@ -90,7 +95,7 @@ def __init__( sort_key: str | SortKey = SortKey.CREATED, sort_reverse: bool = True, sort_keyfunc: SortKeyFunc | None = None, - **progressbar_kwargs, + **progressbar_kwargs: typing.Any, ): self.fd = fd @@ -136,17 +141,19 @@ def __setitem__(self, key: str, bar: bar.ProgressBar): # Just in case someone is using a progressbar with a custom # constructor and forgot to call the super constructor if bar.index == -1: - bar.index = next(bar._index_counter) + bar.index = next( + bar._index_counter # pyright: ignore[reportPrivateUsage] + ) super().__setitem__(key, bar) - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: """Remove a progressbar from the multibar.""" - super().__delitem__(key) - self._finished_at.pop(key, None) - self._labeled.discard(key) + bar_: bar.ProgressBar = self.pop(key) + self._finished_at.pop(bar_, None) + self._labeled.discard(bar_) - def __getitem__(self, key): + def __getitem__(self, key: str): """Get (and create if needed) a progressbar from the multibar.""" try: return super().__getitem__(key) @@ -155,7 +162,7 @@ def __getitem__(self, key): self[key] = progress return progress - def _label_bar(self, bar: bar.ProgressBar): + def _label_bar(self, bar: bar.ProgressBar) -> None: if bar in self._labeled: # pragma: no branch return @@ -169,10 +176,12 @@ def _label_bar(self, bar: bar.ProgressBar): self._labeled.add(bar) bar.widgets.append(self.label_format.format(label=bar.label)) - def render(self, flush: bool = True, force: bool = False): + def render(self, flush: bool = True, force: bool = False) -> None: """Render the multibar to the given stream.""" - now = timeit.default_timer() - expired = now - self.remove_finished if self.remove_finished else None + now: float = timeit.default_timer() + expired: float | None = ( + now - self.remove_finished if self.remove_finished else None + ) # sourcery skip: list-comprehension output: list[str] = [] @@ -221,14 +230,18 @@ def render(self, flush: bool = True, force: bool = False): def _render_bar( self, bar_: bar.ProgressBar, - now, - expired, + now: float, + expired: float | None, ) -> typing.Iterable[str]: - def update(force=True, write=True): # pragma: no cover + def update( + force: bool = True, write: bool = True + ) -> str: # pragma: no cover self._label_bar(bar_) bar_.update(force=force) if write: - yield typing.cast(stream.LastLineStream, bar_.fd).line + return typing.cast(stream.LastLineStream, bar_.fd).line + else: + return '' if bar_.finished(): yield from self._render_finished_bar(bar_, now, expired, update) @@ -238,16 +251,16 @@ def update(force=True, write=True): # pragma: no cover else: if self.initial_format is None: bar_.start() - update() + yield update() else: yield self.initial_format.format(label=bar_.label) def _render_finished_bar( self, bar_: bar.ProgressBar, - now, - expired, - update, + now: float, + expired: float | None, + update: _Update, ) -> typing.Iterable[str]: if bar_ not in self._finished_at: self._finished_at[bar_] = now @@ -273,12 +286,12 @@ def _render_finished_bar( def print( self, - *args, - end='\n', - offset=None, - flush=True, - clear=True, - **kwargs, + *args: typing.Any, + end: str = '\n', + offset: int | None = None, + flush: bool = True, + clear: bool = True, + **kwargs: typing.Any, ): """ Print to the progressbar stream without overwriting the progressbars. @@ -316,12 +329,12 @@ def print( if flush: self.flush() - def flush(self): + def flush(self) -> None: self.fd.write(self._buffer.getvalue()) self._buffer.truncate(0) self.fd.flush() - def run(self, join=True): + def run(self, join: bool = True) -> None: """ Start the multibar render loop and run the progressbars until they have force _thread_finished. @@ -342,13 +355,13 @@ def run(self, join=True): self.render(force=True) return - def start(self): + def start(self) -> None: assert not self._thread, 'Multibar already started' self._thread_closed.set() self._thread = threading.Thread(target=self.run, args=(False,)) self._thread.start() - def join(self, timeout=None): + def join(self, timeout: float | None = None) -> None: if self._thread is not None: self._thread_closed.set() self._thread.join(timeout=timeout) @@ -369,5 +382,10 @@ def __enter__(self): self.start() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> bool | None: self.join() diff --git a/progressbar/shortcuts.py b/progressbar/shortcuts.py index edf0a5b..220c8f2 100644 --- a/progressbar/shortcuts.py +++ b/progressbar/shortcuts.py @@ -1,16 +1,25 @@ -from . import bar +from __future__ import annotations + +import typing + +from . import ( + bar, + widgets as widgets_module, +) + +T = typing.TypeVar('T') def progressbar( - iterator, - min_value: int = 0, - max_value=None, - widgets=None, - prefix=None, - suffix=None, - **kwargs, -): - progressbar = bar.ProgressBar( + iterator: typing.Iterator[T], + min_value: bar.NumberT = 0, + max_value: bar.ValueT = None, + widgets: typing.Sequence[widgets_module.WidgetBase | str] | None = None, + prefix: str | None = None, + suffix: str | None = None, + **kwargs: typing.Any, +) -> typing.Generator[T, None, None]: + progressbar_ = bar.ProgressBar( min_value=min_value, max_value=max_value, widgets=widgets, @@ -18,5 +27,4 @@ def progressbar( suffix=suffix, **kwargs, ) - - yield from progressbar(iterator) + yield from progressbar_(iterator) diff --git a/progressbar/terminal/base.py b/progressbar/terminal/base.py index 1141e52..9cba646 100644 --- a/progressbar/terminal/base.py +++ b/progressbar/terminal/base.py @@ -5,6 +5,7 @@ import colorsys import enum import threading +import typing from collections import defaultdict # Ruff is being stupid and doesn't understand `ClassVar` if it comes from the @@ -26,22 +27,24 @@ class CSI: _code: str _template = ESC + '[{args}{code}' - def __init__(self, code: str, *default_args) -> None: + def __init__(self, code: str, *default_args: typing.Any) -> None: self._code = code self._default_args = default_args - def __call__(self, *args): + def __call__(self, *args: typing.Any) -> str: return self._template.format( args=';'.join(map(str, args or self._default_args)), code=self._code, ) - def __str__(self): + def __str__(self) -> str: return self() class CSINoArg(CSI): - def __call__(self): + def __call__( # pyright: ignore[reportIncompatibleMethodOverride] + self, + ) -> str: return super().__call__() @@ -138,15 +141,15 @@ def __call__(self): # CLEAR_LINE_ALL = CLEAR_LINE.format(n=2) # Clear Line -def clear_line(n): +def clear_line(n: int): return UP(n) + CLEAR_LINE_ALL() + DOWN(n) # Report Cursor Position (CPR), response = [row;column] as row;columnR -class _CPR(str): # pragma: no cover +class _CPR(str): # pragma: no cover # pyright: ignore[reportUnusedClass] _response_lock = threading.Lock() - def __call__(self, stream) -> tuple[int, int]: + def __call__(self, stream: typing.IO[str]) -> tuple[int, int]: res: str = '' with self._response_lock: @@ -156,7 +159,7 @@ def __call__(self, stream) -> tuple[int, int]: while not res.endswith('R'): char = getch() - if char is not None: + if char: res += char res_list = res[2:-1].split(';') @@ -170,11 +173,11 @@ def __call__(self, stream) -> tuple[int, int]: return types.cast(types.Tuple[int, int], tuple(res_list)) - def row(self, stream) -> int: + def row(self, stream: typing.IO[str]) -> int: row, _ = self(stream) return row - def column(self, stream) -> int: + def column(self, stream: typing.IO[str]) -> int: _, column = self(stream) return column @@ -218,7 +221,10 @@ def from_rgb(rgb: types.Tuple[int, int, int]) -> WindowsColors: """ - def color_distance(rgb1, rgb2): + def color_distance( + rgb1: tuple[int, int, int], + rgb2: tuple[int, int, int], + ): return sum((c1 - c2) ** 2 for c1, c2 in zip(rgb1, rgb2)) return min( @@ -241,7 +247,7 @@ class WindowsColor: def __init__(self, color: Color) -> None: self.color = color - def __call__(self, text): + def __call__(self, text: str) -> str: return text ## In the future we might want to use this, but it requires direct ## printing to stdout and all of our surrounding functions expect @@ -252,8 +258,14 @@ def __call__(self, text): # windows.print_color(text, WindowsColors.from_rgb(self.color.rgb)) -class RGB(collections.namedtuple('RGB', ['red', 'green', 'blue'])): - __slots__ = () +class RGB(typing.NamedTuple): + """ + Red, Green, Blue color. + """ + + red: int + green: int + blue: int def __str__(self): return self.rgb @@ -297,7 +309,7 @@ def interpolate(self, end: RGB, step: float) -> RGB: ) -class HSL(collections.namedtuple('HSL', ['hue', 'saturation', 'lightness'])): +class HSL(typing.NamedTuple): """ Hue, Saturation, Lightness color. @@ -306,7 +318,9 @@ class HSL(collections.namedtuple('HSL', ['hue', 'saturation', 'lightness'])): """ - __slots__ = () + hue: float + saturation: float + lightness: float @classmethod def from_rgb(cls, rgb: RGB) -> HSL: @@ -333,22 +347,16 @@ def interpolate(self, end: HSL, step: float) -> HSL: class ColorBase(abc.ABC): + """ + Deprecated, `typing.NamedTuple` does not allow for multiple inheritance so + this class cannot be used with type hints. + """ + def get_color(self, value: float) -> Color: raise NotImplementedError() -class Color( - collections.namedtuple( - 'Color', - [ - 'rgb', - 'hls', - 'name', - 'xterm', - ], - ), - ColorBase, -): +class Color(typing.NamedTuple): """ Color base class. @@ -361,7 +369,10 @@ class Color( but you can be more explicitly if you wish. """ - __slots__ = () + rgb: RGB + hls: HSL + name: str | None + xterm: int | None def __call__(self, value: str) -> str: return self.fg(value) @@ -415,8 +426,11 @@ def interpolate(self, end: Color, step: float) -> Color: self.xterm if step < 0.5 else end.xterm, ) - def __str__(self): - return self.name + def __str__(self) -> str: + if self.name: + return self.name + else: + return str(self.rgb) def __repr__(self) -> str: return f'{self.__class__.__name__}({self.name!r})' @@ -451,15 +465,15 @@ def register( name: types.Optional[str] = None, xterm: types.Optional[int] = None, ) -> Color: + if hls is None: + hls = HSL.from_rgb(rgb) + color = Color(rgb, hls, name, xterm) if name: cls.by_name[name].append(color) cls.by_lowername[name.lower()].append(color) - if hls is None: - hls = HSL.from_rgb(rgb) - cls.by_hex[rgb.hex].append(color) cls.by_rgb[rgb].append(color) cls.by_hls[hls].append(color) @@ -474,8 +488,17 @@ def interpolate(cls, color_a: Color, color_b: Color, step: float) -> Color: return color_a.interpolate(color_b, step) -class ColorGradient(ColorBase): - def __init__(self, *colors: Color, interpolate=Colors.interpolate) -> None: +class ColorGradient: + interpolate: typing.Callable[[Color, Color, float], Color] | None + colors: tuple[Color, ...] + + def __init__( + self, + *colors: Color, + interpolate: ( + typing.Callable[[Color, Color, float], Color] | None + ) = Colors.interpolate, + ) -> None: assert colors self.colors = colors self.interpolate = interpolate @@ -567,7 +590,7 @@ def apply_colors( class DummyColor: - def __call__(self, text): + def __call__(self, text: str): return text def __repr__(self) -> str: @@ -592,7 +615,11 @@ def _start_template(self): def _end_template(self): return super().__call__(self._end_code) - def __call__(self, text, *args): + def __call__( # pyright: ignore[reportIncompatibleMethodOverride] + self, + text: str, + *args: typing.Any, + ) -> str: return self._start_template + text + self._end_template diff --git a/progressbar/widgets.py b/progressbar/widgets.py index c8c3cdf..ffb201e 100644 --- a/progressbar/widgets.py +++ b/progressbar/widgets.py @@ -80,7 +80,7 @@ def wrapper(function, wrapper_): return function @functools.wraps(function) - def wrap(*args, **kwargs): + def wrap(*args: typing.Any, **kwargs: typing.Any): return wrapper_.format(function(*args, **kwargs)) return wrap @@ -123,7 +123,9 @@ class FormatWidgetMixin(abc.ABC): - percentage: Percentage as a float """ - def __init__(self, format: str, new_style: bool = False, **kwargs): + def __init__( + self, format: str, new_style: bool = False, **kwargs: typing.Any + ): self.new_style = new_style self.format = format @@ -182,7 +184,7 @@ class WidthWidgetMixin(abc.ABC): False """ - def __init__(self, min_width=None, max_width=None, **kwargs): + def __init__(self, min_width=None, max_width=None, **kwargs: typing.Any): self.min_width = min_width self.max_width = max_width @@ -350,7 +352,7 @@ class FormatLabel(FormatWidgetMixin, WidgetBase): value=('value', None), ) - def __init__(self, format: str, **kwargs): + def __init__(self, format: str, **kwargs: typing.Any): FormatWidgetMixin.__init__(self, format=format, **kwargs) WidgetBase.__init__(self, **kwargs) @@ -373,7 +375,9 @@ def __call__( class Timer(FormatLabel, TimeSensitiveWidgetBase): """WidgetBase which displays the elapsed seconds.""" - def __init__(self, format='Elapsed Time: %(elapsed)s', **kwargs): + def __init__( + self, format='Elapsed Time: %(elapsed)s', **kwargs: typing.Any + ): if '%s' in format and '%(elapsed)s' not in format: format = format.replace('%s', '%(elapsed)s') @@ -793,7 +797,7 @@ def __call__( class AdaptiveTransferSpeed(FileTransferSpeed, SamplesMixin): """Widget for showing the transfer speed based on the last X samples.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: typing.Any): FileTransferSpeed.__init__(self, **kwargs) SamplesMixin.__init__(self, **kwargs) @@ -873,7 +877,7 @@ def __call__(self, progress: ProgressBarMixinBase, data: Data, width=None): class Counter(FormatWidgetMixin, WidgetBase): """Displays the current count.""" - def __init__(self, format='%(value)d', **kwargs): + def __init__(self, format='%(value)d', **kwargs: typing.Any): FormatWidgetMixin.__init__(self, format=format, **kwargs) WidgetBase.__init__(self, format=format, **kwargs) @@ -905,7 +909,9 @@ class ColoredMixin: class Percentage(FormatWidgetMixin, ColoredMixin, WidgetBase): """Displays the current percentage as a number with a percent sign.""" - def __init__(self, format='%(percentage)3d%%', na='N/A%%', **kwargs): + def __init__( + self, format='%(percentage)3d%%', na='N/A%%', **kwargs: typing.Any + ): self.na = na FormatWidgetMixin.__init__(self, format=format, **kwargs) WidgetBase.__init__(self, format=format, **kwargs) @@ -940,7 +946,7 @@ class SimpleProgress(FormatWidgetMixin, ColoredMixin, WidgetBase): DEFAULT_FORMAT = '%(value_s)s of %(max_value_s)s' - def __init__(self, format=DEFAULT_FORMAT, **kwargs): + def __init__(self, format=DEFAULT_FORMAT, **kwargs: typing.Any): FormatWidgetMixin.__init__(self, format=format, **kwargs) WidgetBase.__init__(self, format=format, **kwargs) self.max_width_cache = dict() @@ -1170,7 +1176,7 @@ def __call__( class VariableMixin: """Mixin to display a custom user variable.""" - def __init__(self, name, **kwargs): + def __init__(self, name, **kwargs: typing.Any): if not isinstance(name, str): raise TypeError('Variable(): argument must be a string') if len(name.split()) > 1: @@ -1189,7 +1195,7 @@ class MultiRangeBar(Bar, VariableMixin): [['Symbol1', amount1], ['Symbol2', amount2], ...] """ - def __init__(self, name, markers, **kwargs): + def __init__(self, name, markers, **kwargs: typing.Any): VariableMixin.__init__(self, name) Bar.__init__(self, **kwargs) self.markers = [string_or_lambda(marker) for marker in markers] @@ -1359,7 +1365,7 @@ def __call__( class FormatLabelBar(FormatLabel, Bar): """A bar which has a formatted label in the center.""" - def __init__(self, format, **kwargs): + def __init__(self, format, **kwargs: typing.Any): FormatLabel.__init__(self, format, **kwargs) Bar.__init__(self, **kwargs) @@ -1399,7 +1405,9 @@ class PercentageLabelBar(Percentage, FormatLabelBar): # %3d adds an extra space that makes it look off-center # %2d keeps the label somewhat consistently in-place - def __init__(self, format='%(percentage)2d%%', na='N/A%%', **kwargs): + def __init__( + self, format='%(percentage)2d%%', na='N/A%%', **kwargs: typing.Any + ): Percentage.__init__(self, format, na=na, **kwargs) FormatLabelBar.__init__(self, format, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index c569a2a..c9ee86e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,32 +182,33 @@ exclude_lines = [ 'if types.TYPE_CHECKING:', '@typing.overload', 'if os.name == .nt.:', + 'typing.Protocol', ] [tool.pyright] include= ['progressbar'] -exclude= ['examples'] +exclude= ['examples', '.tox'] ignore= ['docs'] -#strict = [ -# 'progressbar/algorithms.py', -# 'progressbar/env.py', +strict = [ + 'progressbar/algorithms.py', + 'progressbar/env.py', # 'progressbar/shortcuts.py', -## 'progressbar/multi.py', -## 'progressbar/__init__.py', -# 'progressbar/terminal/__init__.py', -## 'progressbar/terminal/stream.py', -# 'progressbar/terminal/os_specific/__init__.py', + 'progressbar/multi.py', + 'progressbar/__init__.py', + 'progressbar/terminal/__init__.py', + 'progressbar/terminal/stream.py', + 'progressbar/terminal/os_specific/__init__.py', # 'progressbar/terminal/os_specific/posix.py', -## 'progressbar/terminal/os_specific/windows.py', -## 'progressbar/terminal/base.py', -## 'progressbar/terminal/colors.py', -## 'progressbar/widgets.py', -## 'progressbar/utils.py', -# 'progressbar/__about__.py', -## 'progressbar/bar.py', -# 'progressbar/__main__.py', -# 'progressbar/base.py', -#] +# 'progressbar/terminal/os_specific/windows.py', + 'progressbar/terminal/base.py', + 'progressbar/terminal/colors.py', +# 'progressbar/widgets.py', +# 'progressbar/utils.py', + 'progressbar/__about__.py', +# 'progressbar/bar.py', + 'progressbar/__main__.py', + 'progressbar/base.py', +] reportIncompatibleMethodOverride = false reportUnnecessaryIsInstance = false diff --git a/tests/test_color.py b/tests/test_color.py index 90b9b1b..4a368af 100644 --- a/tests/test_color.py +++ b/tests/test_color.py @@ -8,7 +8,7 @@ import progressbar import progressbar.terminal from progressbar import env, terminal, widgets -from progressbar.terminal import Colors, apply_colors, colors +from progressbar.terminal import Color, Colors, apply_colors, colors ENVIRONMENT_VARIABLES = [ 'PROGRESSBAR_ENABLE_COLORS', @@ -227,10 +227,18 @@ def test_colors(monkeypatch) -> None: assert color.fg assert color.bg - assert str(color) assert str(rgb) assert color('test') + color_no_name = Color( + rgb=color.rgb, + hls=color.hls, + name=None, + xterm=color.xterm, + ) + # Test without name + assert str(color_no_name) != str(color) + def test_color() -> None: color = colors.red