Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import tqdm to support jupyter better #812

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/datachain/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from dvc_objects.fs.utils import remove
from fsspec.callbacks import Callback, TqdmCallback

from .progress import Tqdm

if TYPE_CHECKING:
from datachain.client import Client
from datachain.lib.file import File
Expand Down Expand Up @@ -86,9 +84,11 @@ async def download(
size = file.size
if size < 0:
size = await client.get_size(from_path, version_id=file.version)
from tqdm.auto import tqdm

cb = callback or TqdmCallback(
tqdm_kwargs={"desc": odb_fs.name(from_path), "bytes": True, "leave": False},
tqdm_cls=Tqdm,
skshetry marked this conversation as resolved.
Show resolved Hide resolved
tqdm_cls=tqdm,
size=size,
)
try:
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import sqlalchemy as sa
import yaml
from sqlalchemy import Column
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.cache import DataChainCache
from datachain.client import Client
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/client/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from urllib.parse import parse_qs, urlsplit, urlunsplit

from adlfs import AzureBlobFileSystem
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.file import File

Expand Down
2 changes: 1 addition & 1 deletion src/datachain/client/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from dvc_objects.fs.system import reflink
from fsspec.asyn import get_loop, sync
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.cache import DataChainCache
from datachain.client.fileslice import FileWrapper
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/client/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from dateutil.parser import isoparse
from gcsfs import GCSFileSystem
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.file import File

Expand Down
2 changes: 1 addition & 1 deletion src/datachain/client/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from botocore.exceptions import NoCredentialsError
from s3fs import S3FileSystem
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.file import File

Expand Down
2 changes: 1 addition & 1 deletion src/datachain/data_storage/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from sqlalchemy.sql import func
from sqlalchemy.sql.expression import bindparam, cast
from sqlalchemy.sql.selectable import Select
from tqdm import tqdm
from tqdm.auto import tqdm

import datachain.sql.sqlite
from datachain.data_storage import AbstractDBMetastore, AbstractWarehouse
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/data_storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from sqlalchemy import Table, case, select
from sqlalchemy.sql import func
from sqlalchemy.sql.expression import true
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.client import Client
from datachain.data_storage.schema import convert_rows_custom_column_types
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/lib/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pyarrow as pa
from fsspec.core import split_protocol
from pyarrow.dataset import CsvFileFormat, dataset
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.data_model import dict_to_data_model
from datachain.lib.file import ArrowRow, File
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/lib/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from typing import TYPE_CHECKING, Any, Union

import PIL
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.arrow import arrow_type_mapper
from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/listing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from sqlalchemy import Column
from sqlalchemy.sql import func
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.node import DirType, Node, NodeWithPath
from datachain.sql.functions import path as pathfunc
Expand Down
126 changes: 3 additions & 123 deletions src/datachain/progress.py
Original file line number Diff line number Diff line change
@@ -1,138 +1,16 @@
"""Manages progress bars."""

import logging
import sys
from threading import RLock
from typing import Any, ClassVar

from fsspec import Callback
from fsspec.callbacks import TqdmCallback
from tqdm import tqdm

from datachain.utils import env2bool
from tqdm.auto import tqdm

logger = logging.getLogger(__name__)
tqdm.set_lock(RLock())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove this too when we can.

Maybe too much detail, but tqdm does locking by default which is multiprocess-safe. But since we overwrite this, this is only thread-safe.

dvc-data does reset it, so we'll be overriding it anyway as we are importing it someplace. But I hope we can remove it someday there. I have no idea why this was done in dvc.

Copy link
Member

@skshetry skshetry Jan 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I have a suspicion that this is what causes progressbars to jump in dvc)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, makes sense

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dvc-data fixed in iterative/dvc-data#592.



class Tqdm(tqdm):
skshetry marked this conversation as resolved.
Show resolved Hide resolved
"""
maximum-compatibility tqdm-based progressbars
"""

BAR_FMT_DEFAULT = (
"{percentage:3.0f}% {desc}|{bar}|"
"{postfix[info]}{n_fmt}/{total_fmt}"
" [{elapsed}<{remaining}, {rate_fmt:>11}]"
)
# nested bars should have fixed bar widths to align nicely
BAR_FMT_DEFAULT_NESTED = (
"{percentage:3.0f}%|{bar:10}|{desc:{ncols_desc}.{ncols_desc}}"
"{postfix[info]}{n_fmt}/{total_fmt}"
" [{elapsed}<{remaining}, {rate_fmt:>11}]"
)
BAR_FMT_NOTOTAL = "{desc}{bar:b}|{postfix[info]}{n_fmt} [{elapsed}, {rate_fmt:>11}]"
BYTES_DEFAULTS: ClassVar[dict[str, Any]] = {
"unit": "B",
"unit_scale": True,
"unit_divisor": 1024,
"miniters": 1,
}

def __init__(
self,
iterable=None,
disable=None,
level=logging.ERROR,
desc=None,
leave=False,
bar_format=None,
bytes=False,
file=None,
total=None,
postfix=None,
**kwargs,
):
"""
bytes : shortcut for
`unit='B', unit_scale=True, unit_divisor=1024, miniters=1`
desc : persists after `close()`
level : effective logging level for determining `disable`;
used only if `disable` is unspecified
disable : If (default: None) or False,
will be determined by logging level.
May be overridden to `True` due to non-TTY status.
Skip override by specifying env var `DATACHAIN_IGNORE_ISATTY`.
kwargs : anything accepted by `tqdm.tqdm()`
"""
kwargs = kwargs.copy()
if bytes:
kwargs = self.BYTES_DEFAULTS | kwargs
else:
kwargs.setdefault("unit_scale", total > 999 if total else True)
if file is None:
file = sys.stderr
# auto-disable based on `logger.level`
if not disable:
disable = logger.getEffectiveLevel() > level
# auto-disable based on TTY
if (
not disable
and not env2bool("DATACHAIN_IGNORE_ISATTY")
and hasattr(file, "isatty")
):
disable = not file.isatty()
super().__init__(
iterable=iterable,
disable=disable,
leave=leave,
desc=desc,
bar_format="!",
lock_args=(False,),
total=total,
**kwargs,
)
self.postfix = postfix or {"info": ""}
if bar_format is None:
if self.__len__():
self.bar_format = (
self.BAR_FMT_DEFAULT_NESTED if self.pos else self.BAR_FMT_DEFAULT
)
else:
self.bar_format = self.BAR_FMT_NOTOTAL
else:
self.bar_format = bar_format
self.refresh()

def close(self):
self.postfix["info"] = ""
# remove ETA (either unknown or zero); remove completed bar
self.bar_format = self.bar_format.replace("<{remaining}", "").replace(
"|{bar:10}|", " "
)
super().close()

@property
def format_dict(self):
"""inject `ncols_desc` to fill the display width (`ncols`)"""
d = super().format_dict
ncols = d["ncols"] or 80
# assumes `bar_format` has max one of ("ncols_desc" & "ncols_info")

meter = self.format_meter( # type: ignore[call-arg]
ncols_desc=1, ncols_info=1, **d
)
ncols_left = ncols - len(meter) + 1
ncols_left = max(ncols_left, 0)
if ncols_left:
d["ncols_desc"] = d["ncols_info"] = ncols_left
else:
# work-around for zero-width description
d["ncols_desc"] = d["ncols_info"] = 1
d["prefix"] = ""
return d


class CombinedDownloadCallback(Callback):
def set_size(self, size):
# This is a no-op to prevent fsspec's .get_file() from setting the combined
Expand All @@ -148,6 +26,8 @@ def __init__(self, tqdm_kwargs=None, *args, **kwargs):
self.files_count = 0
tqdm_kwargs = tqdm_kwargs or {}
tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count)
kwargs = kwargs or {}
kwargs["tqdm_cls"] = tqdm
skshetry marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(tqdm_kwargs, *args, **kwargs)

def increment_file_count(self, n: int = 1) -> None:
Expand Down
9 changes: 7 additions & 2 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from sqlalchemy.sql.expression import label
from sqlalchemy.sql.schema import TableClause
from sqlalchemy.sql.selectable import Select
from tqdm.auto import tqdm

from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
from datachain.catalog.catalog import clone_catalog_with_cache
Expand Down Expand Up @@ -366,12 +367,16 @@ def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallbac


def get_processed_callback() -> Callback:
return TqdmCallback({"desc": "Processed", "unit": " rows", "leave": False})
return TqdmCallback(
{"desc": "Processed", "unit": " rows", "leave": False}, tqdm_cls=tqdm
skshetry marked this conversation as resolved.
Show resolved Hide resolved
)


def get_generated_callback(is_generator: bool = False) -> Callback:
if is_generator:
return TqdmCallback({"desc": "Generated", "unit": " rows", "leave": False})
return TqdmCallback(
{"desc": "Generated", "unit": " rows", "leave": False}, tqdm_cls=tqdm
)
return DEFAULT_CALLBACK


Expand Down
Loading