Skip to content

Commit

Permalink
fix(tqdm): import tqdm to support jupyter
Browse files (browse the repository at this point in the history)
  • Loading branch information
shcheklein committed Jan 13, 2025
1 parent 12da0c9 commit 2dd0683
Show file tree
Hide file tree
Showing 13 changed files with 25 additions and 17 deletions.
6 changes: 3 additions & 3 deletions src/datachain/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from dvc_objects.fs.utils import remove
from fsspec.callbacks import Callback, TqdmCallback

from .progress import Tqdm

if TYPE_CHECKING:
from datachain.client import Client
from datachain.lib.file import File
Expand Down Expand Up @@ -86,9 +84,11 @@ async def download(
size = file.size
if size < 0:
size = await client.get_size(from_path, version_id=file.version)
from tqdm.auto import tqdm

cb = callback or TqdmCallback(
tqdm_kwargs={"desc": odb_fs.name(from_path), "bytes": True, "leave": False},
tqdm_cls=Tqdm,
tqdm_cls=tqdm,
size=size,
)
try:
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import sqlalchemy as sa
import yaml
from sqlalchemy import Column
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.cache import DataChainCache
from datachain.client import Client
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/client/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from urllib.parse import parse_qs, urlsplit, urlunsplit

from adlfs import AzureBlobFileSystem
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.file import File

Expand Down
2 changes: 1 addition & 1 deletion src/datachain/client/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from dvc_objects.fs.system import reflink
from fsspec.asyn import get_loop, sync
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.cache import DataChainCache
from datachain.client.fileslice import FileWrapper
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/client/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from dateutil.parser import isoparse
from gcsfs import GCSFileSystem
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.file import File

Expand Down
2 changes: 1 addition & 1 deletion src/datachain/client/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from botocore.exceptions import NoCredentialsError
from s3fs import S3FileSystem
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.file import File

Expand Down
2 changes: 1 addition & 1 deletion src/datachain/data_storage/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from sqlalchemy.sql import func
from sqlalchemy.sql.expression import bindparam, cast
from sqlalchemy.sql.selectable import Select
from tqdm import tqdm
from tqdm.auto import tqdm

import datachain.sql.sqlite
from datachain.data_storage import AbstractDBMetastore, AbstractWarehouse
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/data_storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from sqlalchemy import Table, case, select
from sqlalchemy.sql import func
from sqlalchemy.sql.expression import true
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.client import Client
from datachain.data_storage.schema import convert_rows_custom_column_types
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/lib/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pyarrow as pa
from fsspec.core import split_protocol
from pyarrow.dataset import CsvFileFormat, dataset
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.data_model import dict_to_data_model
from datachain.lib.file import ArrowRow, File
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/lib/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from typing import TYPE_CHECKING, Any, Union

import PIL
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.arrow import arrow_type_mapper
from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/listing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from sqlalchemy import Column
from sqlalchemy.sql import func
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.node import DirType, Node, NodeWithPath
from datachain.sql.functions import path as pathfunc
Expand Down
7 changes: 5 additions & 2 deletions src/datachain/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@

from fsspec import Callback
from fsspec.callbacks import TqdmCallback
from tqdm import tqdm
from tqdm.auto import tqdm
from tqdm.auto import tqdm as ntqdm

from datachain.utils import env2bool

logger = logging.getLogger(__name__)
tqdm.set_lock(RLock())


class Tqdm(tqdm):
class Tqdm(ntqdm):
"""
maximum-compatibility tqdm-based progressbars
"""
Expand Down Expand Up @@ -148,6 +149,8 @@ def __init__(self, tqdm_kwargs=None, *args, **kwargs):
self.files_count = 0
tqdm_kwargs = tqdm_kwargs or {}
tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count)
kwargs = kwargs or {}
kwargs["tqdm_cls"] = ntqdm
super().__init__(tqdm_kwargs, *args, **kwargs)

def increment_file_count(self, n: int = 1) -> None:
Expand Down
9 changes: 7 additions & 2 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from sqlalchemy.sql.expression import label
from sqlalchemy.sql.schema import TableClause
from sqlalchemy.sql.selectable import Select
from tqdm.auto import tqdm

from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
from datachain.catalog.catalog import clone_catalog_with_cache
Expand Down Expand Up @@ -366,12 +367,16 @@ def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallbac


def get_processed_callback() -> Callback:
return TqdmCallback({"desc": "Processed", "unit": " rows", "leave": False})
return TqdmCallback(
{"desc": "Processed", "unit": " rows", "leave": False}, tqdm_cls=tqdm
)


def get_generated_callback(is_generator: bool = False) -> Callback:
if is_generator:
return TqdmCallback({"desc": "Generated", "unit": " rows", "leave": False})
return TqdmCallback(
{"desc": "Generated", "unit": " rows", "leave": False}, tqdm_cls=tqdm
)
return DEFAULT_CALLBACK


Expand Down

0 comments on commit 2dd0683

Please sign in to comment.