Skip to content

Commit

Permalink
Update some types to modern Python
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686064548
  • Loading branch information
tomvdw authored and The TensorFlow Datasets Authors committed Oct 15, 2024
1 parent a57f75f commit 46b867b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 51 deletions.
12 changes: 6 additions & 6 deletions tensorflow_datasets/core/logging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import collections
import functools
import threading
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from typing import Any, Callable, TypeVar

from absl import flags
from absl import logging
Expand All @@ -34,15 +34,15 @@
_LoggerMethod = Callable[..., None]


_registered_loggers: Optional[List[base_logger.Logger]] = None
_registered_loggers: list[base_logger.Logger] | None = None

_import_operations: List[Tuple[call_metadata.CallMetadata, int, int]] = []
_import_operations: list[tuple[call_metadata.CallMetadata, int, int]] = []
_import_operations_lock = threading.Lock()

_thread_id_to_builder_init_count = collections.Counter()


def _init_registered_loggers() -> List[base_logger.Logger]:
def _init_registered_loggers() -> list[base_logger.Logger]:
"""Initializes the registered loggers if they are not set yet."""
global _registered_loggers
if _registered_loggers is None:
Expand All @@ -65,7 +65,7 @@ def _log_import_operation():
_import_operations.clear()


def _get_registered_loggers() -> List[base_logger.Logger]:
def _get_registered_loggers() -> list[base_logger.Logger]:
_log_import_operation()
return _init_registered_loggers()

Expand Down Expand Up @@ -188,7 +188,7 @@ class _DsbuilderMethodDecorator(_FunctionDecorator):
IS_PROPERTY: bool = False

@staticmethod
def _get_info(dsbuilder: Any) -> Tuple[str, str, str, str]:
def _get_info(dsbuilder: Any) -> tuple[str, str, str, str]:
"""Gets information about the builder.
Args:
Expand Down
90 changes: 45 additions & 45 deletions tensorflow_datasets/core/logging/base_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from __future__ import annotations

from typing import Any, Dict, Optional, Union
from typing import Any

from etils import epy

Expand Down Expand Up @@ -51,7 +51,7 @@ def tfds_import(
metadata: call_metadata.CallMetadata,
import_time_ms_tensorflow: int,
import_time_ms_dataset_builders: int,
):
) -> None:
"""Callback called when user calls `import tensorflow_datasets`."""
pass

Expand All @@ -60,11 +60,11 @@ def builder_init(
*,
metadata: call_metadata.CallMetadata,
name: str,
data_dir: Optional[str],
config: Optional[str],
version: Optional[str],
data_dir: str | None,
config: str | None,
version: str | None,
is_read_only_builder: bool,
):
) -> None:
"""Callback called when user calls `DatasetBuilder(...)`."""
pass

Expand All @@ -73,10 +73,10 @@ def builder_info(
*,
metadata: call_metadata.CallMetadata,
name: str,
config_name: Optional[str],
config_name: str | None,
version: str,
data_path: str,
):
) -> None:
"""Callback called when user calls `builder.info()`."""
pass

Expand All @@ -85,16 +85,16 @@ def as_dataset(
*,
metadata: call_metadata.CallMetadata,
name: str,
config_name: Optional[str],
config_name: str | None,
version: str,
data_path: str,
split: Optional[type_utils.Tree[splits_lib.SplitArg]],
batch_size: Optional[int],
split: type_utils.Tree[splits_lib.SplitArg] | None,
batch_size: int | None,
shuffle_files: bool,
read_config: read_config_lib.ReadConfig,
as_supervised: bool,
decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]],
):
decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
) -> None:
"""Callback called when user calls `dataset_builder.as_dataset`.
Callback is also triggered by `tfds.load`, which calls `as_dataset`.
Expand Down Expand Up @@ -122,13 +122,13 @@ def download_and_prepare(
*,
metadata: call_metadata.CallMetadata,
name: str,
config_name: Optional[str],
config_name: str | None,
version: str,
data_path: str,
download_dir: Optional[str],
download_config: Optional[download_lib.DownloadConfig],
file_format: Union[None, str, file_adapters.FileFormat],
):
download_dir: str | None,
download_config: download_lib.DownloadConfig | None,
file_format: str | file_adapters.FileFormat | None,
) -> None:
"""Callback called when user calls `dataset_builder.download_and_prepare`."""
pass

Expand All @@ -141,17 +141,17 @@ def builder(
*,
metadata: call_metadata.CallMetadata,
name: str,
try_gcs: Optional[bool],
):
try_gcs: bool | None,
) -> None:
"""Callback called when user calls `tfds.builder(...)`."""
pass

def dataset_collection(
self,
metadata: call_metadata.CallMetadata,
name: str,
loader_kwargs: Optional[Dict[str, Any]],
):
loader_kwargs: dict[str, Any] | None,
) -> None:
"""Callback called when user calls `tfds.dataset_collection(...)`."""
pass

Expand All @@ -160,26 +160,26 @@ def load(
*,
metadata: call_metadata.CallMetadata,
name: str,
split: Optional[type_utils.Tree[splits_lib.SplitArg]],
data_dir: Optional[str],
batch_size: Optional[int],
shuffle_files: Optional[bool],
download: Optional[bool],
as_supervised: Optional[bool],
decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]],
read_config: Optional[read_config_lib.ReadConfig],
with_info: Optional[bool],
try_gcs: Optional[bool],
):
split: type_utils.Tree[splits_lib.SplitArg] | None,
data_dir: str | None,
batch_size: int | None,
shuffle_files: bool | None,
download: bool | None,
as_supervised: bool | None,
decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
read_config: read_config_lib.ReadConfig | None,
with_info: bool | None,
try_gcs: bool | None,
) -> None:
"""Callback called when user calls `tfds.load(...)`."""
pass

def list_builders(
self,
*,
metadata: call_metadata.CallMetadata,
with_community_datasets: Optional[bool],
):
with_community_datasets: bool | None,
) -> None:
"""Callback called when user calls `tfds.list_builders(...)`."""
pass

Expand All @@ -192,12 +192,12 @@ def data_source(
*,
metadata: call_metadata.CallMetadata,
name: str,
split: Optional[type_utils.Tree[splits_lib.SplitArg]],
data_dir: Optional[str],
download: Optional[bool],
decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]],
try_gcs: Optional[bool],
):
split: type_utils.Tree[splits_lib.SplitArg] | None,
data_dir: str | None,
download: bool | None,
decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
try_gcs: bool | None,
) -> None:
"""Callback called when user calls `tfds.data_source(...)`."""
pass

Expand All @@ -206,11 +206,11 @@ def as_data_source(
*,
metadata: call_metadata.CallMetadata,
name: str,
config_name: Optional[str],
config_name: str | None,
version: str,
data_path: str,
split: Optional[type_utils.Tree[splits_lib.SplitArg]],
decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]],
):
split: type_utils.Tree[splits_lib.SplitArg] | None,
decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
) -> None:
"""Callback called when user calls `dataset_builder.as_data_source(...)`."""
pass

0 comments on commit 46b867b

Please sign in to comment.