Skip to content

Commit

Permalink
many more type hinting improvements, not fully strict yet but in prog…
Browse files Browse the repository at this point in the history
…ress
  • Loading branch information
wolph committed Aug 29, 2024
1 parent 4b548b1 commit edb9803
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 116 deletions.
3 changes: 2 additions & 1 deletion progressbar/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import abc
import typing
from datetime import timedelta


class SmoothingAlgorithm(abc.ABC):
@abc.abstractmethod
def __init__(self, **kwargs):
def __init__(self, **kwargs: typing.Any):
raise NotImplementedError

@abc.abstractmethod
Expand Down
1 change: 1 addition & 0 deletions progressbar/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,7 @@ def __iter__(self):
return self

def __next__(self):
value: typing.Any
try:
if self._iterable is None: # pragma: no cover
value = self.value
Expand Down
80 changes: 49 additions & 31 deletions progressbar/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import threading
import time
import timeit
import types
import typing
from datetime import timedelta

Expand All @@ -19,6 +20,10 @@
SortKeyFunc = typing.Callable[[bar.ProgressBar], typing.Any]


class _Update(typing.Protocol):
def __call__(self, force: bool = True, write: bool = True) -> str: ...


class SortKey(str, enum.Enum):
"""
Sort keys for the MultiBar.
Expand Down Expand Up @@ -80,7 +85,7 @@ def __init__(
fd: typing.TextIO = sys.stderr,
prepend_label: bool = True,
append_label: bool = False,
label_format='{label:20.20} ',
label_format: str = '{label:20.20} ',
initial_format: str | None = '{label:20.20} Not yet started',
finished_format: str | None = None,
update_interval: float = 1 / 60.0, # 60fps
Expand All @@ -90,7 +95,7 @@ def __init__(
sort_key: str | SortKey = SortKey.CREATED,
sort_reverse: bool = True,
sort_keyfunc: SortKeyFunc | None = None,
**progressbar_kwargs,
**progressbar_kwargs: typing.Any,
):
self.fd = fd

Expand Down Expand Up @@ -136,17 +141,19 @@ def __setitem__(self, key: str, bar: bar.ProgressBar):
# Just in case someone is using a progressbar with a custom
# constructor and forgot to call the super constructor
if bar.index == -1:
bar.index = next(bar._index_counter)
bar.index = next(
bar._index_counter # pyright: ignore[reportPrivateUsage]
)

super().__setitem__(key, bar)

def __delitem__(self, key):
def __delitem__(self, key: str) -> None:
"""Remove a progressbar from the multibar."""
super().__delitem__(key)
self._finished_at.pop(key, None)
self._labeled.discard(key)
bar_: bar.ProgressBar = self.pop(key)
self._finished_at.pop(bar_, None)
self._labeled.discard(bar_)

def __getitem__(self, key):
def __getitem__(self, key: str):
"""Get (and create if needed) a progressbar from the multibar."""
try:
return super().__getitem__(key)
Expand All @@ -155,7 +162,7 @@ def __getitem__(self, key):
self[key] = progress
return progress

def _label_bar(self, bar: bar.ProgressBar):
def _label_bar(self, bar: bar.ProgressBar) -> None:
if bar in self._labeled: # pragma: no branch
return

Expand All @@ -169,10 +176,12 @@ def _label_bar(self, bar: bar.ProgressBar):
self._labeled.add(bar)
bar.widgets.append(self.label_format.format(label=bar.label))

def render(self, flush: bool = True, force: bool = False):
def render(self, flush: bool = True, force: bool = False) -> None:
"""Render the multibar to the given stream."""
now = timeit.default_timer()
expired = now - self.remove_finished if self.remove_finished else None
now: float = timeit.default_timer()
expired: float | None = (
now - self.remove_finished if self.remove_finished else None
)

# sourcery skip: list-comprehension
output: list[str] = []
Expand Down Expand Up @@ -221,14 +230,18 @@ def render(self, flush: bool = True, force: bool = False):
def _render_bar(
self,
bar_: bar.ProgressBar,
now,
expired,
now: float,
expired: float | None,
) -> typing.Iterable[str]:
def update(force=True, write=True): # pragma: no cover
def update(
force: bool = True, write: bool = True
) -> str: # pragma: no cover
self._label_bar(bar_)
bar_.update(force=force)
if write:
yield typing.cast(stream.LastLineStream, bar_.fd).line
return typing.cast(stream.LastLineStream, bar_.fd).line
else:
return ''

if bar_.finished():
yield from self._render_finished_bar(bar_, now, expired, update)
Expand All @@ -238,16 +251,16 @@ def update(force=True, write=True): # pragma: no cover
else:
if self.initial_format is None:
bar_.start()
update()
yield update()
else:
yield self.initial_format.format(label=bar_.label)

def _render_finished_bar(
self,
bar_: bar.ProgressBar,
now,
expired,
update,
now: float,
expired: float | None,
update: _Update,
) -> typing.Iterable[str]:
if bar_ not in self._finished_at:
self._finished_at[bar_] = now
Expand All @@ -273,12 +286,12 @@ def _render_finished_bar(

def print(
self,
*args,
end='\n',
offset=None,
flush=True,
clear=True,
**kwargs,
*args: typing.Any,
end: str = '\n',
offset: int | None = None,
flush: bool = True,
clear: bool = True,
**kwargs: typing.Any,
):
"""
Print to the progressbar stream without overwriting the progressbars.
Expand Down Expand Up @@ -316,12 +329,12 @@ def print(
if flush:
self.flush()

def flush(self):
def flush(self) -> None:
self.fd.write(self._buffer.getvalue())
self._buffer.truncate(0)
self.fd.flush()

def run(self, join=True):
def run(self, join: bool = True) -> None:
"""
Start the multibar render loop and run the progressbars until they
have force _thread_finished.
Expand All @@ -342,13 +355,13 @@ def run(self, join=True):
self.render(force=True)
return

def start(self):
def start(self) -> None:
assert not self._thread, 'Multibar already started'
self._thread_closed.set()
self._thread = threading.Thread(target=self.run, args=(False,))
self._thread.start()

def join(self, timeout=None):
def join(self, timeout: float | None = None) -> None:
if self._thread is not None:
self._thread_closed.set()
self._thread.join(timeout=timeout)
Expand All @@ -369,5 +382,10 @@ def __enter__(self):
self.start()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
self.join()
32 changes: 20 additions & 12 deletions progressbar/shortcuts.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
from . import bar
from __future__ import annotations

import typing

from . import (
bar,
widgets as widgets_module,
)

T = typing.TypeVar('T')


def progressbar(
iterator,
min_value: int = 0,
max_value=None,
widgets=None,
prefix=None,
suffix=None,
**kwargs,
):
progressbar = bar.ProgressBar(
iterator: typing.Iterator[T],
min_value: bar.NumberT = 0,
max_value: bar.ValueT = None,
widgets: typing.Sequence[widgets_module.WidgetBase | str] | None = None,
prefix: str | None = None,
suffix: str | None = None,
**kwargs: typing.Any,
) -> typing.Generator[T, None, None]:
progressbar_ = bar.ProgressBar(
min_value=min_value,
max_value=max_value,
widgets=widgets,
prefix=prefix,
suffix=suffix,
**kwargs,
)

yield from progressbar(iterator)
yield from progressbar_(iterator)
Loading

0 comments on commit edb9803

Please sign in to comment.