diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 0000000..349f3cd --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,15 @@ +name: pre-commit + +on: + pull_request: + push: + branches: [master] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - run: sudo apt-get update + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + - uses: pre-commit/action@v2.0.3 diff --git a/.gitignore b/.gitignore index 1ad9bec..e323068 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ docs/build .vscode/ .idea/ -__pycache__/ \ No newline at end of file +__pycache__/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..33cb19a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,66 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: trailing-whitespace + name: (Common) Remove trailing whitespaces + - id: mixed-line-ending + name: (Common) Fix mixed line ending + args: ['--fix=lf'] + - id: end-of-file-fixer + name: (Common) Remove extra EOF newlines + - id: check-merge-conflict + name: (Common) Check for merge conflicts + - id: requirements-txt-fixer + name: (Common) Sort "requirements.txt" + - id: fix-encoding-pragma + name: (Python) Remove encoding pragmas + args: ['--remove'] + - id: double-quote-string-fixer + name: (Python) Fix double-quoted strings + - id: debug-statements + name: (Python) Check for debugger imports + - id: check-json + name: (JSON) Check syntax + - id: check-yaml + name: (YAML) Check syntax + - id: check-toml + name: (TOML) Check syntax + - repo: https://github.com/executablebooks/mdformat + rev: 0.7.7 + hooks: + - id: mdformat + name: (Markdown) Format with mdformat + - repo: https://github.com/asottile/pyupgrade + rev: v2.19.4 + hooks: + - id: pyupgrade + name: (Python) Update syntax for newer versions + args: ['--py36-plus'] + - repo: https://github.com/google/yapf + rev: v0.31.0 + hooks: + - id: yapf + name: (Python) Format with yapf + - repo: https://github.com/pycqa/isort + rev: 5.8.0 + hooks: + - id: isort + name: (Python) Sort imports with isort (torchpack) + exclude: examples/ + - id: isort + name: (Python) Sort imports with isort (examples) + files: ^examples/ + args: [--sp, examples/setup.cfg] + - repo: https://github.com/pycqa/flake8 + rev: 3.9.2 + hooks: + - id: flake8 + name: (Python) Check with flake8 + additional_dependencies: [flake8-bugbear, flake8-comprehensions, flake8-docstrings, flake8-executable, flake8-quotes] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.902 + hooks: + - id: mypy + name: (Python) Check with mypy + additional_dependencies: [tokenize-rt, types-pyyaml, types-toml] diff --git a/README.md b/README.md index e592a99..b5c03e1 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # Torchpack + Torchpack is a neural network training interface based on PyTorch, with a focus on flexibility. 
## Installation @@ -9,8 +10,8 @@ pip install torchpack ## Acknowlegements -* [Tensorpack](https://github.com/tensorpack/tensorpack) -* [Detectron2](https://github.com/facebookresearch/detectron2) -* [Jacinle](https://github.com/vacancy/Jacinle) -* [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) -* [MMCV](https://github.com/open-mmlab/mmcv) +- [Tensorpack](https://github.com/tensorpack/tensorpack) +- [Detectron2](https://github.com/facebookresearch/detectron2) +- [Jacinle](https://github.com/vacancy/Jacinle) +- [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) +- [MMCV](https://github.com/open-mmlab/mmcv) diff --git a/docs/Makefile b/docs/Makefile deleted file mode 100644 index d0c3cbf..0000000 --- a/docs/Makefile +++ /dev/null @@ -1,20 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line, and also -# from the environment for the first two. -SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build -SOURCEDIR = source -BUILDDIR = build - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat deleted file mode 100644 index 9534b01..0000000 --- a/docs/make.bat +++ /dev/null @@ -1,35 +0,0 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set SOURCEDIR=source -set BUILDDIR=build - -if "%1" == "" goto help - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% - -:end -popd diff --git a/docs/source/conf.py b/docs/source/conf.py deleted file mode 100644 index 0a11aca..0000000 --- a/docs/source/conf.py +++ /dev/null @@ -1,69 +0,0 @@ -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. 
-# -import os -import sys -sys.path.insert(0, os.path.abspath('.')) -sys.path.insert(0, os.path.abspath('../')) -sys.path.insert(0, os.path.abspath('../..')) - -# Set masterdoc as per https://github.com/readthedocs/readthedocs.org/issues/2569 -master_doc = 'index' - -# -- Project information ----------------------------------------------------- - -project = 'torchpack' -copyright = '2020, MIT Driverless and MIT HAN Lab' -author = 'Torchpack contributors' - -# The full version, including alpha/beta/rc tags -release = '0.3.0' - - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'sphinx.ext.githubpages', - 'sphinx.ext.autosectionlabel', -] - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns = [] - - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'sphinx_rtd_theme' - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst deleted file mode 100644 index 433b314..0000000 --- a/docs/source/index.rst +++ /dev/null @@ -1,46 +0,0 @@ -.. torchpack documentation master file, created by - sphinx-quickstart on Fri Nov 27 10:55:54 2020. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - - -Torchpack -========= - -Torchpack is a neural network training interface based on PyTorch, with a focus on flexibility. - -Installation ------------- - -.. code-block:: bash - - pip install torchpack - -Modules ------------------- -.. toctree:: - :maxdepth: 1 - :caption: Contents: - torchpack.callbacks - torchpack.datasets - torchpack.environ - torchpack.launch - torchpack.models - torchpack.nn - torchpack.train - torchpack.utils - -Indices and tables ------------------- - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` - -Acknowlegements ---------------- -* `Tensorpack `_ -* `Detectron2 `_ -* `Jacinle `_ -* `PyTorch Lightning `_ -* `MMCV `_ diff --git a/docs/source/torchpack.callbacks.rst b/docs/source/torchpack.callbacks.rst deleted file mode 100644 index b44bb13..0000000 --- a/docs/source/torchpack.callbacks.rst +++ /dev/null @@ -1,85 +0,0 @@ -torchpack.callbacks package -=========================== - -Submodules ----------- - -torchpack.callbacks.callback module ------------------------------------ - -.. automodule:: torchpack.callbacks.callback - :members: - :undoc-members: - :show-inheritance: - -torchpack.callbacks.checkpoint module -------------------------------------- - -.. 
automodule:: torchpack.callbacks.checkpoint - :members: - :undoc-members: - :show-inheritance: - -torchpack.callbacks.inference module ------------------------------------- - -.. automodule:: torchpack.callbacks.inference - :members: - :undoc-members: - :show-inheritance: - -torchpack.callbacks.metainfo module ------------------------------------ - -.. automodule:: torchpack.callbacks.metainfo - :members: - :undoc-members: - :show-inheritance: - -torchpack.callbacks.metrics module ----------------------------------- - -.. automodule:: torchpack.callbacks.metrics - :members: - :undoc-members: - :show-inheritance: - -torchpack.callbacks.progress module ------------------------------------ - -.. automodule:: torchpack.callbacks.progress - :members: - :undoc-members: - :show-inheritance: - -torchpack.callbacks.trackers module ------------------------------------ - -.. automodule:: torchpack.callbacks.trackers - :members: - :undoc-members: - :show-inheritance: - -torchpack.callbacks.triggers module ------------------------------------ - -.. automodule:: torchpack.callbacks.triggers - :members: - :undoc-members: - :show-inheritance: - -torchpack.callbacks.writers module ----------------------------------- - -.. automodule:: torchpack.callbacks.writers - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchpack.callbacks - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchpack.datasets.rst b/docs/source/torchpack.datasets.rst deleted file mode 100644 index d322d1e..0000000 --- a/docs/source/torchpack.datasets.rst +++ /dev/null @@ -1,29 +0,0 @@ -torchpack.datasets package -========================== - -Subpackages ------------ - -.. toctree:: - :maxdepth: 4 - - torchpack.datasets.vision - -Submodules ----------- - -torchpack.datasets.dataset module ---------------------------------- - -.. automodule:: torchpack.datasets.dataset - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchpack.datasets - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchpack.datasets.vision.rst b/docs/source/torchpack.datasets.vision.rst deleted file mode 100644 index a7b1bb8..0000000 --- a/docs/source/torchpack.datasets.vision.rst +++ /dev/null @@ -1,29 +0,0 @@ -torchpack.datasets.vision package -================================= - -Submodules ----------- - -torchpack.datasets.vision.cifar module --------------------------------------- - -.. automodule:: torchpack.datasets.vision.cifar - :members: - :undoc-members: - :show-inheritance: - -torchpack.datasets.vision.imagenet module ------------------------------------------ - -.. automodule:: torchpack.datasets.vision.imagenet - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchpack.datasets.vision - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchpack.distributed.rst b/docs/source/torchpack.distributed.rst deleted file mode 100644 index b39486d..0000000 --- a/docs/source/torchpack.distributed.rst +++ /dev/null @@ -1,29 +0,0 @@ -torchpack.distributed package -============================= - -Submodules ----------- - -torchpack.distributed.comm module ---------------------------------- - -.. automodule:: torchpack.distributed.comm - :members: - :undoc-members: - :show-inheritance: - -torchpack.distributed.context module ------------------------------------- - -.. 
automodule:: torchpack.distributed.context - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchpack.distributed - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchpack.environ.rst b/docs/source/torchpack.environ.rst deleted file mode 100644 index f4ece62..0000000 --- a/docs/source/torchpack.environ.rst +++ /dev/null @@ -1,21 +0,0 @@ -torchpack.environ package -========================= - -Submodules ----------- - -torchpack.environ.rundir module -------------------------------- - -.. automodule:: torchpack.environ.rundir - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchpack.environ - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchpack.launch.assets.rst b/docs/source/torchpack.launch.assets.rst deleted file mode 100644 index 3df472e..0000000 --- a/docs/source/torchpack.launch.assets.rst +++ /dev/null @@ -1,21 +0,0 @@ -torchpack.launch.assets package -=============================== - -Submodules ----------- - -torchpack.launch.assets.silentrun module ----------------------------------------- - -.. automodule:: torchpack.launch.assets.silentrun - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchpack.launch.assets - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchpack.launch.launchers.rst b/docs/source/torchpack.launch.launchers.rst deleted file mode 100644 index b1e2cb0..0000000 --- a/docs/source/torchpack.launch.launchers.rst +++ /dev/null @@ -1,21 +0,0 @@ -torchpack.launch.launchers package -================================== - -Submodules ----------- - -torchpack.launch.launchers.drunner module ------------------------------------------ - -.. automodule:: torchpack.launch.launchers.drunner - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchpack.launch.launchers - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchpack.launch.rst b/docs/source/torchpack.launch.rst deleted file mode 100644 index 174baec..0000000 --- a/docs/source/torchpack.launch.rst +++ /dev/null @@ -1,30 +0,0 @@ -torchpack.launch package -======================== - -Subpackages ------------ - -.. toctree:: - :maxdepth: 4 - - torchpack.launch.assets - torchpack.launch.launchers - -Submodules ----------- - -torchpack.launch.main module ----------------------------- - -.. automodule:: torchpack.launch.main - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchpack.launch - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchpack.models.rst b/docs/source/torchpack.models.rst deleted file mode 100644 index 641c114..0000000 --- a/docs/source/torchpack.models.rst +++ /dev/null @@ -1,29 +0,0 @@ -torchpack.models package -======================== - -Subpackages ------------ - -.. toctree:: - :maxdepth: 4 - - torchpack.models.vision - -Submodules ----------- - -torchpack.models.utils module ------------------------------ - -.. automodule:: torchpack.models.utils - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. 
automodule:: torchpack.models - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchpack.models.vision.rst b/docs/source/torchpack.models.vision.rst deleted file mode 100644 index 143c45d..0000000 --- a/docs/source/torchpack.models.vision.rst +++ /dev/null @@ -1,37 +0,0 @@ -torchpack.models.vision package -=============================== - -Submodules ----------- - -torchpack.models.vision.mobilenetv1 module ------------------------------------------- - -.. automodule:: torchpack.models.vision.mobilenetv1 - :members: - :undoc-members: - :show-inheritance: - -torchpack.models.vision.mobilenetv2 module ------------------------------------------- - -.. automodule:: torchpack.models.vision.mobilenetv2 - :members: - :undoc-members: - :show-inheritance: - -torchpack.models.vision.shufflenetv2 module -------------------------------------------- - -.. automodule:: torchpack.models.vision.shufflenetv2 - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchpack.models.vision - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchpack.nn.functional.rst b/docs/source/torchpack.nn.functional.rst deleted file mode 100644 index 62bd965..0000000 --- a/docs/source/torchpack.nn.functional.rst +++ /dev/null @@ -1,21 +0,0 @@ -torchpack.nn.functional package -=============================== - -Submodules ----------- - -torchpack.nn.functional.index module ------------------------------------- - -.. automodule:: torchpack.nn.functional.index - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchpack.nn.functional - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchpack.nn.rst b/docs/source/torchpack.nn.rst deleted file mode 100644 index a51bae0..0000000 --- a/docs/source/torchpack.nn.rst +++ /dev/null @@ -1,18 +0,0 @@ -torchpack.nn package -==================== - -Subpackages ------------ - -.. toctree:: - :maxdepth: 4 - - torchpack.nn.functional - -Module contents ---------------- - -.. automodule:: torchpack.nn - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchpack.rst b/docs/source/torchpack.rst deleted file mode 100644 index 7a85e87..0000000 --- a/docs/source/torchpack.rst +++ /dev/null @@ -1,37 +0,0 @@ -torchpack package -================= - -Subpackages ------------ - -.. toctree:: - :maxdepth: 4 - - torchpack.callbacks - torchpack.datasets - torchpack.distributed - torchpack.environ - torchpack.launch - torchpack.models - torchpack.nn - torchpack.train - torchpack.utils - -Submodules ----------- - -torchpack.version module ------------------------- - -.. automodule:: torchpack.version - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchpack - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchpack.train.rst b/docs/source/torchpack.train.rst deleted file mode 100644 index d742c4e..0000000 --- a/docs/source/torchpack.train.rst +++ /dev/null @@ -1,37 +0,0 @@ -torchpack.train package -======================= - -Submodules ----------- - -torchpack.train.exception module --------------------------------- - -.. automodule:: torchpack.train.exception - :members: - :undoc-members: - :show-inheritance: - -torchpack.train.summary module ------------------------------- - -.. 
automodule:: torchpack.train.summary - :members: - :undoc-members: - :show-inheritance: - -torchpack.train.trainer module ------------------------------- - -.. automodule:: torchpack.train.trainer - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchpack.train - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchpack.utils.rst b/docs/source/torchpack.utils.rst deleted file mode 100644 index 1ce2251..0000000 --- a/docs/source/torchpack.utils.rst +++ /dev/null @@ -1,93 +0,0 @@ -torchpack.utils package -======================= - -Submodules ----------- - -torchpack.utils.config module ------------------------------ - -.. automodule:: torchpack.utils.config - :members: - :undoc-members: - :show-inheritance: - -torchpack.utils.device module ------------------------------ - -.. automodule:: torchpack.utils.device - :members: - :undoc-members: - :show-inheritance: - -torchpack.utils.fs module -------------------------- - -.. automodule:: torchpack.utils.fs - :members: - :undoc-members: - :show-inheritance: - -torchpack.utils.git module --------------------------- - -.. automodule:: torchpack.utils.git - :members: - :undoc-members: - :show-inheritance: - -torchpack.utils.humanize module -------------------------------- - -.. automodule:: torchpack.utils.humanize - :members: - :undoc-members: - :show-inheritance: - -torchpack.utils.imp module --------------------------- - -.. automodule:: torchpack.utils.imp - :members: - :undoc-members: - :show-inheritance: - -torchpack.utils.io module -------------------------- - -.. automodule:: torchpack.utils.io - :members: - :undoc-members: - :show-inheritance: - -torchpack.utils.logging module ------------------------------- - -.. automodule:: torchpack.utils.logging - :members: - :undoc-members: - :show-inheritance: - -torchpack.utils.matching module -------------------------------- - -.. automodule:: torchpack.utils.matching - :members: - :undoc-members: - :show-inheritance: - -torchpack.utils.typing module ------------------------------ - -.. automodule:: torchpack.utils.typing - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. 
automodule:: torchpack.utils - :members: - :undoc-members: - :show-inheritance: diff --git a/examples/image-classification/configs/default.yaml b/examples/image-classification/configs/default.yaml deleted file mode 100644 index e634709..0000000 --- a/examples/image-classification/configs/default.yaml +++ /dev/null @@ -1,2 +0,0 @@ -workers_per_gpu: 8 -amp_enabled: false \ No newline at end of file diff --git a/examples/imcls/configs/default.yaml b/examples/imcls/configs/default.yaml new file mode 100644 index 0000000..1ab1bb1 --- /dev/null +++ b/examples/imcls/configs/default.yaml @@ -0,0 +1,3 @@ +workers_per_gpu: 8 +amp: + enabled: false diff --git a/examples/image-classification/configs/imagenet/default.yaml b/examples/imcls/configs/imagenet/default.yaml similarity index 72% rename from examples/image-classification/configs/imagenet/default.yaml rename to examples/imcls/configs/imagenet/default.yaml index c472689..d74bf03 100644 --- a/examples/image-classification/configs/imagenet/default.yaml +++ b/examples/imcls/configs/imagenet/default.yaml @@ -1,4 +1,4 @@ dataset: name: imagenet root: /dataset/imagenet/ - num_classes: 1000 \ No newline at end of file + num_classes: 1000 diff --git a/examples/image-classification/configs/imagenet/mobilenetv2.yaml b/examples/imcls/configs/imagenet/mobilenetv2.yaml similarity index 92% rename from examples/image-classification/configs/imagenet/mobilenetv2.yaml rename to examples/imcls/configs/imagenet/mobilenetv2.yaml index c783ab8..c88cbf1 100644 --- a/examples/image-classification/configs/imagenet/mobilenetv2.yaml +++ b/examples/imcls/configs/imagenet/mobilenetv2.yaml @@ -15,4 +15,4 @@ optimizer: weight_decay: 4.0e-5 scheduler: - name: cosine \ No newline at end of file + name: cosine diff --git a/examples/image-classification/core/__init__.py b/examples/imcls/core/__init__.py similarity index 100% rename from examples/image-classification/core/__init__.py rename to examples/imcls/core/__init__.py diff --git a/examples/image-classification/core/builder.py b/examples/imcls/core/builder.py similarity index 69% rename from examples/image-classification/core/builder.py rename to examples/imcls/core/builder.py index b6feecd..1b862ec 100644 --- a/examples/image-classification/core/builder.py +++ b/examples/imcls/core/builder.py @@ -3,7 +3,6 @@ import torch import torch.optim from torch import nn - from torchpack.datasets.vision import ImageNet from torchpack.models.vision import MobileNetV1, MobileNetV2, ShuffleNetV2 from torchpack.utils.config import configs @@ -17,8 +16,10 @@ def make_dataset() -> Dataset: if configs.dataset.name == 'imagenet': - dataset = ImageNet(root=configs.dataset.root, - num_classes=configs.dataset.num_classes) + dataset = ImageNet( + root=configs.dataset.root, + num_classes=configs.dataset.num_classes, + ) else: raise NotImplementedError(configs.dataset.name) return dataset @@ -26,14 +27,20 @@ def make_dataset() -> Dataset: def make_model() -> nn.Module: if configs.model.name == 'mobilenetv1': - model = MobileNetV1(num_classes=configs.dataset.num_classes, - width_multiplier=configs.model.width_multipler) + model = MobileNetV1( + num_classes=configs.dataset.num_classes, + width_multiplier=configs.model.width_multipler, + ) elif configs.model.name == 'mobilenetv2': - model = MobileNetV2(num_classes=configs.dataset.num_classes, - width_multiplier=configs.model.width_multipler) + model = MobileNetV2( + num_classes=configs.dataset.num_classes, + width_multiplier=configs.model.width_multipler, + ) elif configs.model.name == 
'shufflenetv2': - model = ShuffleNetV2(num_classes=configs.dataset.num_classes, - width_multiplier=configs.model.width_multipler) + model = ShuffleNetV2( + num_classes=configs.dataset.num_classes, + width_multiplier=configs.model.width_multipler, + ) else: raise NotImplementedError(configs.model.name) return model @@ -53,7 +60,8 @@ def make_optimizer(model: nn.Module) -> Optimizer: model.parameters(), lr=configs.optimizer.lr, momentum=configs.optimizer.momentum, - weight_decay=configs.optimizer.weight_decay) + weight_decay=configs.optimizer.weight_decay, + ) else: raise NotImplementedError(configs.optimizer.name) return optimizer @@ -62,7 +70,9 @@ def make_optimizer(model: nn.Module) -> Optimizer: def make_scheduler(optimizer: Optimizer) -> Scheduler: if configs.scheduler.name == 'cosine': scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, T_max=configs.num_epochs) + optimizer, + T_max=configs.num_epochs, + ) else: raise NotImplementedError(configs.scheduler.name) return scheduler diff --git a/examples/image-classification/core/trainers.py b/examples/imcls/core/trainers.py similarity index 83% rename from examples/image-classification/core/trainers.py rename to examples/imcls/core/trainers.py index 12e0664..e859571 100644 --- a/examples/image-classification/core/trainers.py +++ b/examples/imcls/core/trainers.py @@ -2,7 +2,6 @@ from torch import nn from torch.cuda import amp - from torchpack.train import Trainer from torchpack.utils.typing import Optimizer, Scheduler @@ -10,13 +9,16 @@ class ClassificationTrainer(Trainer): - def __init__(self, - *, - model: nn.Module, - criterion: Callable, - optimizer: Optimizer, - scheduler: Scheduler, - amp_enabled: bool = False) -> None: + + def __init__( + self, + *, + model: nn.Module, + criterion: Callable, + optimizer: Optimizer, + scheduler: Scheduler, + amp_enabled: bool = False, + ) -> None: self.model = model self.criterion = criterion self.optimizer = optimizer @@ -31,10 +33,10 @@ def _run_step(self, feed_dict: Dict[str, Any]) -> Dict[str, Any]: inputs = feed_dict['image'].cuda(non_blocking=True) targets = feed_dict['class'].cuda(non_blocking=True) - with amp.autocast(enabled=self.amp_enabled): + with amp.autocast(enabled=self.model.training and self.amp_enabled): outputs = self.model(inputs) - if outputs.requires_grad: + if self.model.training: loss = self.criterion(outputs, targets) self.summary.add_scalar('loss', loss.item()) @@ -50,7 +52,7 @@ def _after_epoch(self) -> None: self.scheduler.step() def _state_dict(self) -> Dict[str, Any]: - state_dict = dict() + state_dict = {} state_dict['model'] = self.model.state_dict() state_dict['scaler'] = self.scaler.state_dict() state_dict['optimizer'] = self.optimizer.state_dict() diff --git a/examples/image-classification/train.py b/examples/imcls/train.py similarity index 74% rename from examples/image-classification/train.py rename to examples/imcls/train.py index f5882be..5f3c749 100644 --- a/examples/image-classification/train.py +++ b/examples/imcls/train.py @@ -40,48 +40,55 @@ def main() -> None: logger.info(f'Experiment started: "{args.run_dir}".' 
+ '\n' + f'{configs}') dataset = builder.make_dataset() - dataflow = dict() + dataflow = {} for split in dataset: sampler = torch.utils.data.DistributedSampler( dataset[split], num_replicas=dist.size(), rank=dist.rank(), - shuffle=(split == 'train')) + shuffle=(split == 'train'), + ) dataflow[split] = torch.utils.data.DataLoader( dataset[split], batch_size=configs.batch_size // dist.size(), sampler=sampler, num_workers=configs.workers_per_gpu, - pin_memory=True) + pin_memory=True, + ) model = builder.make_model() model = torch.nn.parallel.DistributedDataParallel( model.cuda(), device_ids=[dist.local_rank()], - find_unused_parameters=True) + ) criterion = builder.make_criterion() optimizer = builder.make_optimizer(model) scheduler = builder.make_scheduler(optimizer) - trainer = ClassificationTrainer(model=model, - criterion=criterion, - optimizer=optimizer, - scheduler=scheduler, - amp_enabled=configs.amp_enabled) + trainer = ClassificationTrainer( + model=model, + criterion=criterion, + optimizer=optimizer, + scheduler=scheduler, + amp_enabled=configs.amp.enabled, + ) trainer.train_with_defaults( dataflow['train'], num_epochs=configs.num_epochs, callbacks=[ SaverRestore(), - InferenceRunner(dataflow['test'], - callbacks=[ - TopKCategoricalAccuracy(k=1, name='acc/top1'), - TopKCategoricalAccuracy(k=5, name='acc/top5') - ]), + InferenceRunner( + dataflow['test'], + callbacks=[ + TopKCategoricalAccuracy(k=1, name='acc/top1'), + TopKCategoricalAccuracy(k=5, name='acc/top5'), + ], + ), MaxSaver('acc/top1'), - Saver() - ]) + Saver(), + ], + ) if __name__ == '__main__': diff --git a/examples/setup.cfg b/examples/setup.cfg new file mode 100644 index 0000000..8299232 --- /dev/null +++ b/examples/setup.cfg @@ -0,0 +1,2 @@ +[isort] +known_first_party = core diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 945c9b4..0000000 --- a/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -. \ No newline at end of file diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..431d7ec --- /dev/null +++ b/setup.cfg @@ -0,0 +1,19 @@ +[yapf] +column_limit = 79 +based_on_style = google +spaces_around_power_operator = true +split_before_arithmetic_operator = true +split_before_logical_operator = true +split_before_bitwise_operator = true + +[isort] +known_first_party = torchpack + +[pydocstyle] +convention = google + +[flake8] +select = B, C, D, E, F, P, T4, W, B9 +ignore = D10, E501, E722, W503 +per-file-ignores = + __init__.py: F401, F403 diff --git a/setup.py b/setup.py index 5f9e46d..e83ea27 100644 --- a/setup.py +++ b/setup.py @@ -10,23 +10,17 @@ author_email='zhijianliu.cs@gmail.com', url='https://github.com/zhijian-liu/torchpack', install_requires=[ - 'h5py', 'loguru', 'multimethod', 'numpy', 'pyyaml', - 'scipy', - 'tensorboard', 'tensorpack', - 'toml', 'torch>=1.5.0', 'torchvision', 'tqdm', ], python_requires='>=3.6', entry_points={ - 'console_scripts': [ - 'torchpack = torchpack.launch:main', - ], + 'console_scripts': ['torchpack = torchpack.launch:main'], }, ) diff --git a/torchpack/callbacks/callback.py b/torchpack/callbacks/callback.py index feef136..5628f92 100644 --- a/torchpack/callbacks/callback.py +++ b/torchpack/callbacks/callback.py @@ -1,16 +1,20 @@ import traceback -from typing import Any, Callable, Dict, List, Optional +import typing +from typing import Any, Callable, Dict, Iterator, List, Optional -from .. 
import distributed as dist -from ..utils.typing import Trainer +from torchpack import distributed as dist + +if typing.TYPE_CHECKING: + from torchpack.train import Trainer +else: + Trainer = None __all__ = ['Callback', 'LambdaCallback', 'ProxyCallback', 'Callbacks'] class Callback: - """ - Base class for all callbacks. - """ + """Base class for all callbacks.""" + master_only: bool = False @property def @@ -30,9 +34,7 @@ def before_train(self) -> None: self._before_train() def _before_train(self) -> None: - """ - Called before training. - """ + """Define what to do before training.""" pass def before_epoch(self) -> None: @@ -40,9 +42,7 @@ def before_epoch(self) -> None: self._before_epoch() def _before_epoch(self) -> None: - """ - Called before every epoch. - """ + """Define what to do before every epoch.""" pass def before_step(self, feed_dict: Dict[str, Any]) -> None: @@ -50,9 +50,7 @@ def before_step(self, feed_dict: Dict[str, Any]) -> None: self._before_step(feed_dict) def _before_step(self, feed_dict: Dict[str, Any]) -> None: - """ - Called before every step. - """ + """Define what to do before every step.""" pass def after_step(self, output_dict: Dict[str, Any]) -> None: @@ -60,9 +58,7 @@ def after_step(self, output_dict: Dict[str, Any]) -> None: self._after_step(output_dict) def _after_step(self, output_dict: Dict[str, Any]) -> None: - """ - Called after every step. - """ + """Define what to do after every step.""" pass def trigger_step(self) -> None: @@ -70,9 +66,7 @@ def trigger_step(self) -> None: self._trigger_step() def _trigger_step(self) -> None: - """ - Called after after step. - """ + """Define what to do after each step.""" pass def after_epoch(self) -> None: @@ -80,9 +74,7 @@ def after_epoch(self) -> None: self._after_epoch() def _after_epoch(self) -> None: - """ - Called after every epoch. - """ + """Define what to do after every epoch.""" pass def trigger_epoch(self) -> None: @@ -90,9 +82,7 @@ def trigger_epoch(self) -> None: self._trigger_epoch() def _trigger_epoch(self) -> None: - """ - Called after after epoch. - """ + """Define what to do after each epoch.""" pass def trigger(self) -> None: @@ -100,10 +90,10 @@ def trigger(self) -> None: self._trigger() def _trigger(self) -> None: - """ - Override this method to define a general trigger behavior, to be used with trigger schedulers. - Note that the schedulers (e.g. :class:`PeriodicTrigger`) might call this method - both inside an epoch and after an epoch. + """Define a general trigger behavior, to be used with trigger schedulers. + + Note that the schedulers (e.g. :class:`PeriodicTrigger`) might call + this method both inside an epoch and after an epoch. """ pass @@ -115,16 +105,14 @@ def after_train(self) -> None: traceback.print_exc() def _after_train(self) -> None: - """ - Called after training. - """ + """Define what to do after training.""" pass def state_dict(self) -> Dict[str, Any]: - return self._state_dict() if self.enabled else dict() + return self._state_dict() if self.enabled else {} def _state_dict(self) -> Dict[str, Any]: - return dict() + return {} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self.enabled:
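
The hook pairs above (a public wrapper that checks `enabled`, plus an underscored method to override) are the whole contract a custom callback has to satisfy. A minimal sketch under that contract — the class name and the counted statistic are illustrative, and it assumes `Callback` is re-exported from `torchpack.callbacks` as the `__all__` above suggests:

```python
from typing import Any, Dict

from torchpack.callbacks import Callback


class StepCounter(Callback):
    """Hypothetical callback: counts steps and survives checkpointing."""

    def _before_train(self) -> None:
        self.num_steps = 0

    def _after_step(self, output_dict: Dict[str, Any]) -> None:
        self.num_steps += 1

    def _state_dict(self) -> Dict[str, Any]:
        return {'num_steps': self.num_steps}

    def _load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.num_steps = state_dict['num_steps']
```
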
- """ - def __init__(self, - *, - set_trainer: Optional[Callable] = None, - before_train: Optional[Callable] = None, - before_epoch: Optional[Callable] = None, - before_step: Optional[Callable] = None, - after_step: Optional[Callable] = None, - trigger_step: Optional[Callable] = None, - after_epoch: Optional[Callable] = None, - trigger_epoch: Optional[Callable] = None, - trigger: Optional[Callable] = None, - after_train: Optional[Callable] = None, - state_dict: Optional[Callable] = None, - load_state_dict: Optional[Callable] = None, - master_only: bool = False): + """A callback created with lambda functions.""" + + def __init__( + self, + *, + set_trainer: Optional[Callable] = None, + before_train: Optional[Callable] = None, + before_epoch: Optional[Callable] = None, + before_step: Optional[Callable] = None, + after_step: Optional[Callable] = None, + trigger_step: Optional[Callable] = None, + after_epoch: Optional[Callable] = None, + trigger_epoch: Optional[Callable] = None, + trigger: Optional[Callable] = None, + after_train: Optional[Callable] = None, + state_dict: Optional[Callable] = None, + load_state_dict: Optional[Callable] = None, + master_only: bool = False, + ) -> None: self.set_trainer_fn = set_trainer self.before_train_fn = before_train self.before_epoch_fn = before_epoch @@ -211,7 +200,7 @@ def _after_train(self) -> None: self.after_train_fn(self) def _state_dict(self) -> Dict[str, Any]: - return self.state_dict_fn(self) if self.state_dict_fn else dict() + return self.state_dict_fn(self) if self.state_dict_fn else {} def _load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self.load_state_dict_fn: @@ -219,9 +208,8 @@ def _load_state_dict(self, state_dict: Dict[str, Any]) -> None: class ProxyCallback(Callback): - """ - A callback which proxy all methods to another callback. - """ + """A callback which proxy all methods to another callback.""" + def __init__(self, callback: Callback) -> None: assert isinstance(callback, Callback), type(callback) self.callback = callback @@ -267,9 +255,8 @@ def __str__(self) -> str: class Callbacks(Callback): - """ - A container to hold callbacks. 
- """ + """A container to hold callbacks.""" + def __init__(self, callbacks: List[Callback]) -> None: for callback in callbacks: assert isinstance(callback, Callback), type(callback) @@ -316,12 +303,12 @@ def _after_train(self) -> None: callback.after_train() def _state_dict(self) -> Dict[str, Any]: - state_dict = dict() + state_dict = {} for k, callback in enumerate(self.callbacks): - local_state = callback.state_dict() - if local_state: + state = callback.state_dict() + if state: name = f'{str(callback).lower()}.{k}' - state_dict[name] = local_state + state_dict[name] = state return state_dict def _load_state_dict(self, state_dict: Dict[str, Any]) -> None: @@ -335,3 +322,6 @@ def __getitem__(self, index: int) -> Callback: def __len__(self) -> int: return len(self.callbacks) + + def __iter__(self) -> Iterator[Callback]: + return iter(self.callbacks) diff --git a/torchpack/callbacks/checkpoint.py b/torchpack/callbacks/checkpoint.py index 479cbd4..920333c 100644 --- a/torchpack/callbacks/checkpoint.py +++ b/torchpack/callbacks/checkpoint.py @@ -1,32 +1,41 @@ import glob import os +import typing from collections import deque -from typing import Any, ClassVar, Dict, Optional +from typing import Any, ClassVar, Deque, Dict, Optional + +from torchpack.environ import get_run_dir +from torchpack.utils import fs, io +from torchpack.utils.logging import logger -from ..environ import get_run_dir -from ..utils import fs, io -from ..utils.logging import logger -from ..utils.typing import Trainer from .callback import Callback +if typing.TYPE_CHECKING: + from torchpack.train import Trainer +else: + Trainer = None + __all__ = ['Saver', 'MinSaver', 'MaxSaver', 'SaverRestore'] class Saver(Callback): - """ - Save the checkpoint once triggered. - """ + """Save the checkpoint once triggered.""" + master_only: bool = True - def __init__(self, *, max_to_keep: int = 4, - save_dir: Optional[str] = None) -> None: + def __init__( + self, + *, + max_to_keep: int = 4, + save_dir: Optional[str] = None, + ) -> None: self.max_to_keep = max_to_keep if save_dir is None: save_dir = os.path.join(get_run_dir(), 'checkpoints') self.save_dir = fs.normpath(save_dir) def _set_trainer(self, trainer: Trainer) -> None: - self.checkpoints = deque() + self.checkpoints: Deque[str] = deque() for fpath in sorted(glob.glob(os.path.join(self.save_dir, 'step-*.pt')), key=os.path.getmtime): @@ -49,8 +58,8 @@ def _trigger(self) -> None: def _add_checkpoint(self, fpath: str) -> None: self.checkpoints.append(fpath) - while self.max_to_keep is not None and \ - len(self.checkpoints) > self.max_to_keep: + while (self.max_to_keep is not None + and len(self.checkpoints) > self.max_to_keep): fpath = self.checkpoints.popleft() try: fs.remove(fpath) @@ -60,17 +69,18 @@ def _add_checkpoint(self, fpath: str) -> None: class BestSaver(Callback): - """ - Save the checkpoint with best value of some scalar in `trainer.summary`. 
- """ + """Save the checkpoint with best value of some scalar in `trainer.summary`.""" + master_only: bool = True extreme: ClassVar[str] - def __init__(self, - scalar: str, - *, - name: Optional[str] = None, - save_dir: Optional[str] = None) -> None: + def __init__( + self, + scalar: str, + *, + name: Optional[str] = None, + save_dir: Optional[str] = None, + ) -> None: self.scalar = scalar if name is None: name = self.extreme + '-' + scalar.replace('/', '-') @@ -98,8 +108,9 @@ def _trigger(self): return self.step = step - if self.best is None or (self.extreme == 'min' and value < self.best[1]) \ - or (self.extreme == 'max' and value > self.best[1]): + if (self.best is None + or (self.extreme == 'min' and value < self.best[1]) + or (self.extreme == 'max' and value > self.best[1])): self.best = (step, value) save_path = os.path.join(self.save_dir, self.name + '.pt') try: @@ -122,20 +133,19 @@ def _load_state_dict(self, state_dict: Dict[str, Any]) -> None: class MinSaver(BestSaver): - """ - Save the checkpoint with minimum value of some scalar in `trainer.summary`. - """ + """Save the checkpoint with minimum value of some scalar in `trainer.summary`.""" + extreme: ClassVar[str] = 'min' class MaxSaver(BestSaver): - """ - Save the checkpoint with maximum value of some scalar in `trainer.summary`. - """ + """Save the checkpoint with maximum value of some scalar in `trainer.summary`.""" + extreme: ClassVar[str] = 'max' class SaverRestore(Callback): + def __init__(self, load_dir: Optional[str] = None) -> None: if load_dir is None: load_dir = os.path.join(get_run_dir(), 'checkpoints') diff --git a/torchpack/callbacks/inference.py b/torchpack/callbacks/inference.py index fbd1d38..b54edbf 100644 --- a/torchpack/callbacks/inference.py +++ b/torchpack/callbacks/inference.py @@ -1,22 +1,25 @@ import time +import typing from typing import List import torch from torch.utils.data import DataLoader -from ..utils import humanize -from ..utils.logging import logger -from ..utils import tqdm -from ..utils.typing import Trainer +from torchpack.utils import humanize, tqdm +from torchpack.utils.logging import logger + from .callback import Callback, Callbacks +if typing.TYPE_CHECKING: + from torchpack.train import Trainer +else: + Trainer = None + __all__ = ['InferenceRunner'] class InferenceRunner(Callback): - """ - A callback that runs inference with a list of :class:`Callback`. 
- """ + """Run inference with a list of :class:`Callback`.""" def __init__(self, dataflow: DataLoader, *, callbacks: List[Callback]) -> None: diff --git a/torchpack/callbacks/metainfo.py b/torchpack/callbacks/metainfo.py index 61e5d62..999d7d2 100644 --- a/torchpack/callbacks/metainfo.py +++ b/torchpack/callbacks/metainfo.py @@ -1,9 +1,10 @@ import os -from typing import Optional +from typing import Dict, Optional + +from torchpack.environ import get_run_dir +from torchpack.utils import fs, git, io +from torchpack.utils.config import configs -from ..environ import get_run_dir -from ..utils import fs, git, io -from ..utils.config import configs from .callback import Callback __all__ = ['MetaInfoSaver'] @@ -19,15 +20,19 @@ def __init__(self, save_dir: Optional[str] = None) -> None: def _before_train(self) -> None: if configs: - io.save(os.path.join(self.save_dir, 'configs.yaml'), - configs.dict()) + io.save( + os.path.join(self.save_dir, 'configs.yaml'), + configs.dict(), + ) if git.is_inside_work_tree(): - metainfo = dict() + metainfo: Dict[str, Optional[str]] = {} remote = git.get_remote_url() if remote: metainfo['remote'] = remote metainfo['commit'] = git.get_commit_hash() - io.save(os.path.join(self.save_dir, 'git.json'), - metainfo, - indent=4) + io.save( + os.path.join(self.save_dir, 'git.json'), + metainfo, + indent=4, + ) diff --git a/torchpack/callbacks/metrics.py b/torchpack/callbacks/metrics.py index 0ecbad8..45fc426 100644 --- a/torchpack/callbacks/metrics.py +++ b/torchpack/callbacks/metrics.py @@ -2,7 +2,8 @@ import torch -from .. import distributed as dist +from torchpack import distributed as dist + from .callback import Callback __all__ = [ @@ -12,12 +13,15 @@ class TopKCategoricalAccuracy(Callback): - def __init__(self, - k: int, - *, - output_tensor: str = 'outputs', - target_tensor: str = 'targets', - name: str = 'accuracy') -> None: + + def __init__( + self, + k: int, + *, + output_tensor: str = 'outputs', + target_tensor: str = 'targets', + name: str = 'accuracy', + ) -> None: self.k = k self.output_tensor = output_tensor self.target_tensor = target_tensor @@ -45,23 +49,31 @@ def _after_epoch(self) -> None: class CategoricalAccuracy(TopKCategoricalAccuracy): - def __init__(self, - *, - output_tensor: str = 'outputs', - target_tensor: str = 'targets', - name: str = 'accuracy') -> None: - super().__init__(k=1, - output_tensor=output_tensor, - target_tensor=target_tensor, - name=name) + + def __init__( + self, + *, + output_tensor: str = 'outputs', + target_tensor: str = 'targets', + name: str = 'accuracy', + ) -> None: + super().__init__( + k=1, + output_tensor=output_tensor, + target_tensor=target_tensor, + name=name, + ) class MeanSquaredError(Callback): - def __init__(self, - *, - output_tensor: str = 'outputs', - target_tensor: str = 'targets', - name: str = 'error') -> None: + + def __init__( + self, + *, + output_tensor: str = 'outputs', + target_tensor: str = 'targets', + name: str = 'error', + ) -> None: self.output_tensor = output_tensor self.target_tensor = target_tensor self.name = name @@ -86,11 +98,14 @@ def _after_epoch(self) -> None: class MeanAbsoluteError(Callback): - def __init__(self, - *, - output_tensor: str = 'outputs', - target_tensor: str = 'targets', - name: str = 'error') -> None: + + def __init__( + self, + *, + output_tensor: str = 'outputs', + target_tensor: str = 'targets', + name: str = 'error', + ) -> None: self.output_tensor = output_tensor self.target_tensor = target_tensor self.name = name diff --git a/torchpack/callbacks/progress.py 
b/torchpack/callbacks/progress.py index 9163074..80e456b 100644 --- a/torchpack/callbacks/progress.py +++ b/torchpack/callbacks/progress.py @@ -1,21 +1,21 @@ import time from collections import deque -from typing import List, Union +from typing import Deque, List, Union import numpy as np -from ..utils import humanize, tqdm -from ..utils.logging import logger -from ..utils.matching import NameMatcher +from torchpack.utils import humanize, tqdm +from torchpack.utils.logging import logger +from torchpack.utils.matching import NameMatcher + from .callback import Callback __all__ = ['ProgressBar', 'EstimatedTimeLeft'] class ProgressBar(Callback): - """ - A progress bar based on `tqdm`. - """ + """A progress bar based on `tqdm`.""" + master_only: bool = True def __init__(self, scalars: Union[str, List[str]] = '*') -> None: @@ -25,12 +25,12 @@ def _before_epoch(self) -> None: self.pbar = tqdm.trange(self.trainer.steps_per_epoch) def _trigger_step(self) -> None: - texts = [] + texts: List[str] = [] for name in sorted(self.trainer.summary.keys()): step, scalar = self.trainer.summary[name][-1] if self.matcher.match(name) and step == self.trainer.global_step and \ isinstance(scalar, (int, float)): - texts.append('[{}] = {:.3g}'.format(name, scalar)) + texts.append(f'[{name}] = {scalar:.3g}') if texts: self.pbar.set_description(', '.join(texts)) self.pbar.update() @@ -40,16 +40,15 @@ def _after_epoch(self) -> None: class EstimatedTimeLeft(Callback): - """ - Estimate the time left until completion. - """ + """Estimate the time left until completion.""" + master_only: bool = True def __init__(self, *, last_k_epochs: int = 8) -> None: self.last_k_epochs = last_k_epochs def _before_train(self) -> None: - self.times = deque(maxlen=self.last_k_epochs) + self.times: Deque[float] = deque(maxlen=self.last_k_epochs) self.last_time = time.perf_counter() def _trigger_epoch(self) -> None: @@ -57,7 +56,7 @@ def _trigger_epoch(self) -> None: self.times.append(time.perf_counter() - self.last_time) self.last_time = time.perf_counter() - estimated_time = (self.trainer.num_epochs - - self.trainer.epoch_num) * np.mean(self.times) + estimated_time = (self.trainer.num_epochs + - self.trainer.epoch_num) * np.mean(self.times) logger.info('Estimated time left: {}.'.format( humanize.naturaldelta(estimated_time))) diff --git a/torchpack/callbacks/trackers.py b/torchpack/callbacks/trackers.py index 22d011e..c68b039 100644 --- a/torchpack/callbacks/trackers.py +++ b/torchpack/callbacks/trackers.py @@ -1,7 +1,7 @@ import multiprocessing as mp import os import time -from queue import Empty +from queue import Empty, Queue from typing import List, Optional import numpy as np @@ -10,19 +10,21 @@ start_proc_mask_signal) from tensorpack.utils.nvml import NVMLContext -from ..utils.logging import logger +from torchpack.utils.logging import logger + from .callback import Callback __all__ = ['GPUUtilizationTracker', 'ThroughputTracker'] class GPUUtilizationTracker(Callback): + """Track the average GPU utilization within an epoch. + + It will start a process to track GPU utilization through NVML every second + within the epoch (the time of `trigger_epoch` is not included). This + callback creates a process, so it is not safe to use with MPI.
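
The tracker above coordinates its polling process with a queue and an event; stripped of NVML, the handshake looks like the sketch below (the utilization reading is faked with a random number):

```python
import multiprocessing as mp
import random
import time


def worker(queue, event):
    # Sample a fake utilization reading about once per second; when the
    # parent sets `event`, report the running average and start over.
    meters = []
    while True:
        meters.append(random.uniform(0.0, 100.0))
        if event.wait(timeout=1.0):
            queue.put(sum(meters) / len(meters))
            meters.clear()
            event.clear()


if __name__ == '__main__':
    queue, event = mp.Queue(), mp.Event()
    process = mp.Process(target=worker, args=(queue, event), daemon=True)
    process.start()
    time.sleep(3)       # one "epoch" of work
    event.set()         # ask the worker for its epoch summary
    print(queue.get())  # average utilization over that epoch
    process.terminate()
```
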
- """ + master_only: bool = True def __init__(self, *, devices: Optional[List[int]] = None) -> None: @@ -58,14 +60,16 @@ def _worker(devices, queue, event): meters = meters[:max(len(meters) - 1, 1)] queue.put(np.mean(meters, axis=0)) event.clear() - except: + except Exception: queue.put(None) def _before_train(self) -> None: - self.queue = mp.Queue() + self.queue: Queue[np.ndarray] = mp.Queue() self.event = mp.Event() - self.process = mp.Process(target=self._worker, - args=(self.devices, self.queue, self.event)) + self.process = mp.Process( + target=self._worker, + args=(self.devices, self.queue, self.event), + ) ensure_proc_terminate(self.process) start_proc_mask_signal(self.process) @@ -90,7 +94,9 @@ def _trigger_epoch(self) -> None: if len(self.devices) > 1: for k, device in enumerate(self.devices): self.trainer.summary.add_scalar( - 'utilization/gpu{}'.format(device), meters[k]) + f'utilization/gpu{device}', + meters[k], + ) def _after_train(self) -> None: if self.process.is_alive(): @@ -98,9 +104,11 @@ def _after_train(self) -> None: class ThroughputTracker(Callback): + """Track the throughput within an epoch. + + Note that the time of `trigger_epoch` is not included. """ - Track the throughput within an epoch (the time of `trigger_epoch` is not included). - """ + master_only: bool = True def __init__(self, *, samples_per_step: Optional[int] = None) -> None: @@ -116,14 +124,18 @@ def _after_epoch(self) -> None: self.end_time = time.perf_counter() def _trigger_epoch(self) -> None: - steps_per_sec = (self.trainer.global_step - - self.last_step) / (self.end_time - self.start_time) + steps_per_sec = (self.trainer.global_step + - self.last_step) / (self.end_time - self.start_time) self.last_step = self.trainer.global_step if self.samples_per_step is None: - self.trainer.summary.add_scalar('throughput/steps_per_sec', - steps_per_sec) + self.trainer.summary.add_scalar( + 'throughput/steps_per_sec', + steps_per_sec, + ) else: samples_per_sec = steps_per_sec * self.samples_per_step - self.trainer.summary.add_scalar('throughput/samples_per_sec', - samples_per_sec) + self.trainer.summary.add_scalar( + 'throughput/samples_per_sec', + samples_per_sec, + ) diff --git a/torchpack/callbacks/triggers.py b/torchpack/callbacks/triggers.py index a42e558..e327af0 100644 --- a/torchpack/callbacks/triggers.py +++ b/torchpack/callbacks/triggers.py @@ -6,11 +6,13 @@ class EnableCallbackIf(ProxyCallback): - """ - Enable the callback only if some condition holds. - """ - def __init__(self, callback: Callback, - predicate: Callable[[Callback], bool]) -> None: + """Enable the callback only if some condition holds.""" + + def __init__( + self, + callback: Callback, + predicate: Callable[[Callback], bool], + ) -> None: super().__init__(callback) self.predicate = predicate @@ -43,14 +45,15 @@ def __str__(self) -> str: class PeriodicTrigger(ProxyCallback): - """ - Trigger the callback every k steps or every k epochs. - """ - def __init__(self, - callback: Callback, - *, - every_k_epochs: Optional[int] = None, - every_k_steps: Optional[int] = None) -> None: + """Trigger the callback every k steps or every k epochs.""" + + def __init__( + self, + callback: Callback, + *, + every_k_epochs: Optional[int] = None, + every_k_steps: Optional[int] = None, + ) -> None: super().__init__(callback) assert every_k_epochs is not None or every_k_steps is not None, \ '`every_k_epochs` and `every_k_steps` cannot both be None!' 
@@ -58,11 +61,13 @@ def __init__(self, self.every_k_steps = every_k_steps def _trigger_step(self) -> None: - if self.every_k_steps is not None and self.trainer.global_step % self.every_k_steps == 0: + if (self.every_k_steps is not None + and self.trainer.global_step % self.every_k_steps == 0): super()._trigger() def _trigger_epoch(self) -> None: - if self.every_k_epochs is not None and self.trainer.epoch_num % self.every_k_epochs == 0: + if (self.every_k_epochs is not None + and self.trainer.epoch_num % self.every_k_epochs == 0): super()._trigger() def __str__(self) -> str: @@ -70,24 +75,29 @@ def __str__(self) -> str: class PeriodicCallback(EnableCallbackIf): - """ - Enable the callback every k steps or every k epochs. + """Enable the callback every k steps or every k epochs. + Note that this can only make a callback less frequent. """ - def __init__(self, - callback: Callback, - *, - every_k_epochs: Optional[int] = None, - every_k_steps: Optional[int] = None) -> None: + + def __init__( + self, + callback: Callback, + *, + every_k_epochs: Optional[int] = None, + every_k_steps: Optional[int] = None, + ) -> None: assert every_k_epochs is not None or every_k_steps is not None, \ '`every_k_epochs` and `every_k_steps` cannot both be None!' self.every_k_epochs = every_k_epochs self.every_k_steps = every_k_steps def predicate(self) -> bool: - if self.every_k_epochs is not None and self.trainer.epoch_num % self.every_k_epochs == 0: + if (self.every_k_epochs is not None + and self.trainer.epoch_num % self.every_k_epochs == 0): return True - if self.every_k_steps is not None and self.trainer.global_step % self.every_k_steps == 0: + if (self.every_k_steps is not None + and self.trainer.global_step % self.every_k_steps == 0): return True return False diff --git a/torchpack/callbacks/writers.py b/torchpack/callbacks/writers.py index 8b81721..0e34818 100644 --- a/torchpack/callbacks/writers.py +++ b/torchpack/callbacks/writers.py @@ -1,23 +1,28 @@ import json import os -from typing import List, Optional, Union +import typing +from typing import Dict, List, Optional, Union import numpy as np -from ..environ import get_run_dir -from ..utils import fs -from ..utils.logging import logger -from ..utils.matching import NameMatcher -from ..utils.typing import Trainer +from torchpack.environ import get_run_dir +from torchpack.utils import fs +from torchpack.utils.logging import logger +from torchpack.utils.matching import NameMatcher + from .callback import Callback +if typing.TYPE_CHECKING: + from torchpack.train import Trainer +else: + Trainer = None + __all__ = ['SummaryWriter', 'ConsoleWriter', 'TFEventWriter', 'JSONLWriter'] class SummaryWriter(Callback): - """ - Base class for all summary writers. - """ + """Base class for all summary writers.""" + master_only: bool = True def add_scalar(self, name: str, scalar: Union[int, float]) -> None: @@ -36,14 +41,13 @@ def _add_image(self, name: str, tensor: np.ndarray) -> None: class ConsoleWriter(SummaryWriter): - """ - Write scalar summaries to console (and logger). 
- """ + """Write scalar summaries to console (and logger).""" + def __init__(self, scalars: Union[str, List[str]] = '*') -> None: self.matcher = NameMatcher(patterns=scalars) def _set_trainer(self, trainer: Trainer) -> None: - self.scalars = dict() + self.scalars: Dict[str, Union[int, float]] = {} def _add_scalar(self, name: str, scalar: Union[int, float]) -> None: self.scalars[name] = scalar @@ -55,16 +59,15 @@ def _trigger(self) -> None: texts = [] for name, scalar in sorted(self.scalars.items()): if self.matcher.match(name): - texts.append('[{}] = {:.5g}'.format(name, scalar)) + texts.append(f'[{name}] = {scalar:.5g}') if texts: logger.info('\n+ '.join([''] + texts)) self.scalars.clear() class TFEventWriter(SummaryWriter): - """ - Write summaries to TensorFlow event file. - """ + """Write summaries to TensorFlow event file.""" + def __init__(self, *, save_dir: Optional[str] = None) -> None: if save_dir is None: save_dir = os.path.join(get_run_dir(), 'tensorboard') @@ -85,16 +88,15 @@ def _after_train(self) -> None: class JSONLWriter(SummaryWriter): - """ - Write scalar summaries to JSONL file. - """ + """Write scalar summaries to JSONL file.""" + def __init__(self, save_dir: Optional[str] = None) -> None: if save_dir is None: save_dir = os.path.join(get_run_dir(), 'summary') self.save_dir = fs.normpath(save_dir) def _set_trainer(self, trainer: Trainer) -> None: - self.scalars = dict() + self.scalars: Dict[str, Union[int, float]] = {} fs.makedir(self.save_dir) self.file = open(os.path.join(self.save_dir, 'scalars.jsonl'), 'a') @@ -109,7 +111,7 @@ def _trigger_epoch(self) -> None: def _trigger(self) -> None: if self.scalars: - summary = { + summary: Dict[str, Union[int, float]] = { 'epoch_num': self.trainer.epoch_num, 'local_step': self.trainer.local_step, 'global_step': self.trainer.global_step, diff --git a/torchpack/datasets/vision/cifar.py b/torchpack/datasets/vision/cifar.py index 18fe975..07cc0ea 100644 --- a/torchpack/datasets/vision/cifar.py +++ b/torchpack/datasets/vision/cifar.py @@ -10,18 +10,23 @@ class CIFAR10Dataset(datasets.CIFAR10): - def __init__(self, - *, - root: str, - split: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = True) -> None: - super().__init__(root=root, - train=(split == 'train'), - transform=transform, - target_transform=target_transform, - download=download) + + def __init__( + self, + *, + root: str, + split: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = True, + ) -> None: + super().__init__( + root=root, + train=(split == 'train'), + transform=transform, + target_transform=target_transform, + download=download, + ) def __getitem__(self, index: int) -> Dict[str, Any]: image, label = super().__getitem__(index) @@ -29,18 +34,23 @@ def __getitem__(self, index: int) -> Dict[str, Any]: class CIFAR100Dataset(datasets.CIFAR100): - def __init__(self, - *, - root: str, - split: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = True) -> None: - super().__init__(root=root, - train=(split == 'train'), - transform=transform, - target_transform=target_transform, - download=download) + + def __init__( + self, + *, + root: str, + split: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = True, + ) -> None: + super().__init__( + root=root, + train=(split == 'train'), + transform=transform, + target_transform=target_transform, 
+ download=download, + ) def __getitem__(self, index: int) -> Dict[str, Any]: image, label = super().__getitem__(index) @@ -48,11 +58,14 @@ def __getitem__(self, index: int) -> Dict[str, Any]: class CIFAR(Dataset): - def __init__(self, - *, - root: str, - num_classes: int = 10, - transforms: Optional[Dict[str, Callable]] = None) -> None: + + def __init__( + self, + *, + root: str, + num_classes: int = 10, + transforms: Optional[Dict[str, Callable]] = None, + ) -> None: if num_classes == 10: CIFARDataset = CIFAR10Dataset elif num_classes == 100: @@ -61,26 +74,31 @@ def __init__(self, raise NotImplementedError(f'CIFAR-{num_classes} is not supported.') if transforms is None: - transforms = dict() + transforms = {} if 'train' not in transforms: transforms['train'] = Compose([ RandomCrop(32, padding=4), RandomHorizontalFlip(), ToTensor(), - Normalize(mean=[0.4914, 0.4822, 0.4465], - std=[0.2023, 0.1994, 0.2010]) + Normalize( + mean=[0.4914, 0.4822, 0.4465], + std=[0.2023, 0.1994, 0.2010], + ), ]) if 'test' not in transforms: transforms['test'] = Compose([ Resize(32), ToTensor(), - Normalize(mean=[0.4914, 0.4822, 0.4465], - std=[0.2023, 0.1994, 0.2010]) + Normalize( + mean=[0.4914, 0.4822, 0.4465], + std=[0.2023, 0.1994, 0.2010], + ), ]) super().__init__({ - split: CIFARDataset(root=root, - split=split, - transform=transforms[split]) - for split in ['train', 'test'] + split: CIFARDataset( + root=root, + split=split, + transform=transforms[split], + ) for split in ['train', 'test'] }) diff --git a/torchpack/datasets/vision/imagenet.py b/torchpack/datasets/vision/imagenet.py index e2472cc..375313a 100644 --- a/torchpack/datasets/vision/imagenet.py +++ b/torchpack/datasets/vision/imagenet.py @@ -12,16 +12,21 @@ class ImageNetDataset(datasets.ImageNet): - def __init__(self, - *, - root: str, - split: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None) -> None: - super().__init__(root=root, - split=('train' if split == 'train' else 'val'), - transform=transform, - target_transform=target_transform) + + def __init__( + self, + *, + root: str, + split: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__( + root=root, + split=('train' if split == 'train' else 'val'), + transform=transform, + target_transform=target_transform, + ) def __getitem__(self, index: int) -> Dict[str, Any]: with warnings.catch_warnings(): @@ -31,38 +36,46 @@ def __getitem__(self, index: int) -> Dict[str, Any]: class ImageNet(Dataset): - def __init__(self, - *, - root: str, - num_classes: int = 1000, - transforms: Optional[Dict[str, Callable]] = None) -> None: + + def __init__( + self, + *, + root: str, + num_classes: int = 1000, + transforms: Optional[Dict[str, Callable]] = None, + ) -> None: if transforms is None: - transforms = dict() + transforms = {} if 'train' not in transforms: transforms['train'] = Compose([ RandomResizedCrop(224), RandomHorizontalFlip(), ToTensor(), - Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) + Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ) ]) if 'test' not in transforms: transforms['test'] = Compose([ Resize(256), CenterCrop(224), ToTensor(), - Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) + Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ) ]) super().__init__({ - split: ImageNetDataset(root=root, - split=split, - transform=transforms[split]) - for split in ['train', 'test'] + split: 
ImageNetDataset( + root=root, + split=split, + transform=transforms[split], + ) for split in ['train', 'test'] }) - indices = dict() + indices = {} for k in range(num_classes): indices[k * (1000 // num_classes)] = k diff --git a/torchpack/distributed/comm.py b/torchpack/distributed/comm.py index a0764ef..e5cdc3d 100644 --- a/torchpack/distributed/comm.py +++ b/torchpack/distributed/comm.py @@ -33,9 +33,9 @@ def allgather(data: Any) -> List: max_size = max(sizes) # receiving tensors from all ranks - tensors = [torch.ByteTensor(size=(max_size, )).cuda() for _ in sizes] + tensors = [torch.ByteTensor(size=(max_size,)).cuda() for _ in sizes] if local_size != max_size: - padding = torch.ByteTensor(size=(max_size - local_size, )).cuda() + padding = torch.ByteTensor(size=(max_size - local_size,)).cuda() tensor = torch.cat((tensor, padding), dim=0) torch.distributed.all_gather(tensors, tensor) diff --git a/torchpack/distributed/context.py b/torchpack/distributed/context.py index 828b1ab..b424ea4 100644 --- a/torchpack/distributed/context.py +++ b/torchpack/distributed/context.py @@ -4,14 +4,19 @@ import torch.distributed from torch.distributed.constants import default_pg_timeout +from torchpack.utils.logging import logger +from torchpack.utils.network import get_free_tcp_port + __all__ = ['init', 'size', 'rank', 'local_size', 'local_rank', 'is_master'] _world_size, _world_rank = 1, 0 _local_size, _local_rank = 1, 0 -def init(backend: int = 'nccl', - timeout: timedelta = default_pg_timeout) -> None: +def init( + backend: str = 'nccl', + timeout: timedelta = default_pg_timeout, +) -> None: from mpi4py import MPI world_comm = MPI.COMM_WORLD local_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED) @@ -23,14 +28,17 @@ def init(backend: int = 'nccl', if 'MASTER_HOST' in os.environ: master_host = 'tcp://' + os.environ['MASTER_HOST'] else: - from torchpack.launch.launchers.drunner import get_free_tcp_port - master_host = 'tcp://localhost:{}'.format(get_free_tcp_port()) - print("Distributed environment not detected, fall back to default") - torch.distributed.init_process_group(backend=backend, - init_method=master_host, - timeout=timeout, - world_size=_world_size, - rank=_world_rank) + master_host = f'tcp://localhost:{get_free_tcp_port()}' + logger.warning( + 'Distributed environment not detected, fall back to default') + + torch.distributed.init_process_group( + backend=backend, + init_method=master_host, + timeout=timeout, + world_size=_world_size, + rank=_world_rank, + ) def size() -> int: diff --git a/torchpack/environ/__init__.py b/torchpack/environ/__init__.py index f3f6b69..d55d27c 100644 --- a/torchpack/environ/__init__.py +++ b/torchpack/environ/__init__.py @@ -1 +1 @@ -from .rundir import * \ No newline at end of file +from .rundir import * diff --git a/torchpack/environ/rundir.py b/torchpack/environ/rundir.py index ec911c6..2e095c3 100644 --- a/torchpack/environ/rundir.py +++ b/torchpack/environ/rundir.py @@ -1,15 +1,18 @@ import os -from .. 
import distributed as dist -from ..utils import fs, git -from ..utils.config import configs -from ..utils.logging import logger +from torchpack import distributed as dist +from torchpack.utils import fs, git +from torchpack.utils.config import configs +from torchpack.utils.logging import logger __all__ = ['get_run_dir', 'set_run_dir', 'auto_set_run_dir'] +_run_dir = None + def get_run_dir() -> str: global _run_dir + assert _run_dir is not None return _run_dir @@ -20,17 +23,21 @@ def set_run_dir(dirpath: str) -> None: prefix = '{time}' if dist.size() > 1: - prefix += '_{:04d}'.format(dist.rank()) - logger.add(os.path.join(_run_dir, 'logging', prefix + '.log'), - format=('{time:YYYY-MM-DD HH:mm:ss.SSS} | ' - '{name}:{function}:{line} | ' - '{level} | {message}')) + prefix += f'_{dist.rank():04d}' + logger.add( + os.path.join(_run_dir, 'logging', prefix + '.log'), + format=('{time:YYYY-MM-DD HH:mm:ss.SSS} | ' + '{name}:{function}:{line} | ' + '{level} | {message}'), + ) def auto_set_run_dir() -> str: tags = ['run'] if git.is_inside_work_tree(): - tags.append(git.get_commit_hash()[:8]) + hash = git.get_commit_hash() + if hash: + tags.append(hash[:8]) if configs: tags.append(configs.hash()[:8]) run_dir = os.path.join('runs', '-'.join(tags)) diff --git a/torchpack/launch/__init__.py b/torchpack/launch/__init__.py index c313e3a..15b6a64 100644 --- a/torchpack/launch/__init__.py +++ b/torchpack/launch/__init__.py @@ -1 +1 @@ -from .main import * \ No newline at end of file +from .main import * diff --git a/torchpack/launch/launchers/drunner.py b/torchpack/launch/launchers/drunner.py index 053647c..8d6225c 100644 --- a/torchpack/launch/launchers/drunner.py +++ b/torchpack/launch/launchers/drunner.py @@ -1,10 +1,11 @@ import argparse import os import re -import socket import sys from shlex import quote +from torchpack.utils.network import get_free_tcp_port + __all__ = ['main'] @@ -13,13 +14,6 @@ def is_exportable(v): return not any(re.match(r, v) for r in IGNORE_REGEXES) -def get_free_tcp_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp: - tcp.bind(('0.0.0.0', 0)) - port = tcp.getsockname()[1] - return port - - def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( @@ -58,7 +52,7 @@ def main() -> None: if not args.hosts: if args.hostfile: hosts = [] - with open(args.hostfile, 'r') as fp: + with open(args.hostfile) as fp: for line in fp.read().splitlines(): hostname = line.split()[0] slots = line.split('=')[1] @@ -91,13 +85,13 @@ def main() -> None: '{environ} ' '-mca pml ob1 -mca btl ^openib ' '-mca btl_tcp_if_exclude docker0,lo ' - '{command}'.format(nproc=args.nproc, - hosts=args.hosts, - environ=' '.join( - f'-x {key}' - for key in sorted(environ.keys()) - if is_exportable(key)), - command=command)) + '{command}'.format( + nproc=args.nproc, + hosts=args.hosts, + environ=' '.join( + f'-x {key}' for key in sorted(environ.keys()) + if is_exportable(key)), + command=command)) if args.verbose: print(command) diff --git a/torchpack/models/utils.py b/torchpack/models/utils.py index 17e9bc7..f9eef2f 100644 --- a/torchpack/models/utils.py +++ b/torchpack/models/utils.py @@ -3,9 +3,13 @@ __all__ = ['make_divisible'] -# from https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py -def make_divisible(v: int, divisor: int, *, - min_value: Optional[int] = None) -> int: +# from https://tinyurl.com/vke23tt5 +def make_divisible( + v: int, + divisor: int, + *, + min_value: Optional[int] = None, +) -> int: if min_value is None: 
min_value = divisor x = max(min_value, int(v + divisor / 2) // divisor * divisor) diff --git a/torchpack/models/vision/mobilenetv1.py b/torchpack/models/vision/mobilenetv1.py index e4fba35..69bd10d 100644 --- a/torchpack/models/vision/mobilenetv1.py +++ b/torchpack/models/vision/mobilenetv1.py @@ -1,4 +1,4 @@ -from typing import ClassVar, List, Tuple, Union +from typing import ClassVar, List import torch from torch import nn @@ -9,25 +9,30 @@ class MobileBlockV1(nn.Sequential): - def __init__(self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - *, - stride: int = 1) -> None: + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + *, + stride: int = 1, + ) -> None: self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size self.stride = stride super().__init__( - nn.Conv2d(in_channels, - in_channels, - kernel_size, - stride=stride, - padding=kernel_size // 2, - groups=in_channels, - bias=False), + nn.Conv2d( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=in_channels, + bias=False, + ), nn.BatchNorm2d(in_channels), nn.ReLU(True), nn.Conv2d(in_channels, out_channels, 1, bias=False), @@ -41,22 +46,29 @@ class MobileNetV1(nn.Module): 32, (64, 1, 1), (128, 2, 2), (256, 2, 2), (512, 6, 2), (1024, 2, 2) ] - def __init__(self, - *, - in_channels: int = 3, - num_classes: int = 1000, - width_multiplier: float = 1) -> None: + def __init__( + self, + *, + in_channels: int = 3, + num_classes: int = 1000, + width_multiplier: float = 1, + ) -> None: super().__init__() - out_channels = make_divisible(self.layers[0] * width_multiplier, 8) + out_channels = make_divisible( + int(self.layers[0] * width_multiplier), + divisor=8, + ) layers = nn.ModuleList([ nn.Sequential( - nn.Conv2d(in_channels, - out_channels, - 3, - stride=2, - padding=1, - bias=False), + nn.Conv2d( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + bias=False, + ), nn.BatchNorm2d(out_channels), nn.ReLU(True), ) @@ -64,7 +76,10 @@ def __init__(self, in_channels = out_channels for out_channels, num_blocks, strides in self.layers[1:]: - out_channels = make_divisible(out_channels * width_multiplier, 8) + out_channels = make_divisible( + int(out_channels * width_multiplier), + divisor=8, + ) for stride in [strides] + [1] * (num_blocks - 1): layers.append( MobileBlockV1(in_channels, out_channels, 3, stride=stride)) @@ -77,9 +92,11 @@ def __init__(self, def reset_parameters(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, - mode='fan_out', - nonlinearity='relu') + nn.init.kaiming_normal_( + m.weight, + mode='fan_out', + nonlinearity='relu', + ) if m.bias is not None: nn.init.zeros_(m.bias) if isinstance(m, nn.Linear): diff --git a/torchpack/models/vision/mobilenetv2.py b/torchpack/models/vision/mobilenetv2.py index 4cd9391..076b645 100644 --- a/torchpack/models/vision/mobilenetv2.py +++ b/torchpack/models/vision/mobilenetv2.py @@ -1,4 +1,4 @@ -from typing import ClassVar, List, Tuple, Union +from typing import ClassVar, List import torch from torch import nn @@ -9,13 +9,16 @@ class MobileBlockV2(nn.Module): - def __init__(self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - *, - stride: int = 1, - expansion: int = 1) -> None: + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + *, + stride: int = 1, + expansion: int = 1, + ) -> None: 
super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -25,13 +28,15 @@ def __init__(self, if expansion == 1: self.layers = nn.Sequential( - nn.Conv2d(in_channels, - in_channels, - kernel_size, - stride=stride, - padding=kernel_size // 2, - groups=in_channels, - bias=False), + nn.Conv2d( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=in_channels, + bias=False, + ), nn.BatchNorm2d(in_channels), nn.ReLU6(True), nn.Conv2d(in_channels, out_channels, 1, bias=False), @@ -43,13 +48,15 @@ def __init__(self, nn.Conv2d(in_channels, mid_channels, 1, bias=False), nn.BatchNorm2d(mid_channels), nn.ReLU6(True), - nn.Conv2d(mid_channels, - mid_channels, - kernel_size, - stride=stride, - padding=kernel_size // 2, - groups=mid_channels, - bias=False), + nn.Conv2d( + mid_channels, + mid_channels, + kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=mid_channels, + bias=False, + ), nn.BatchNorm2d(mid_channels), nn.ReLU6(True), nn.Conv2d(mid_channels, out_channels, 1, bias=False), @@ -69,22 +76,29 @@ class MobileNetV2(nn.Module): (6, 96, 3, 1), (6, 160, 3, 2), (6, 320, 1, 1), 1280 ] - def __init__(self, - *, - in_channels: int = 3, - num_classes: int = 1000, - width_multiplier: float = 1) -> None: + def __init__( + self, + *, + in_channels: int = 3, + num_classes: int = 1000, + width_multiplier: float = 1, + ) -> None: super().__init__() - out_channels = make_divisible(self.layers[0] * width_multiplier, 8) + out_channels = make_divisible( + int(self.layers[0] * width_multiplier), + divisor=8, + ) layers = nn.ModuleList([ nn.Sequential( - nn.Conv2d(in_channels, - out_channels, - 3, - stride=2, - padding=1, - bias=False), + nn.Conv2d( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + bias=False, + ), nn.BatchNorm2d(out_channels), nn.ReLU6(True), ) @@ -92,19 +106,26 @@ def __init__(self, in_channels = out_channels for expansion, out_channels, num_blocks, strides in self.layers[1:-1]: - out_channels = make_divisible(out_channels * width_multiplier, 8) + out_channels = make_divisible( + int(out_channels * width_multiplier), + divisor=8, + ) for stride in [strides] + [1] * (num_blocks - 1): layers.append( - MobileBlockV2(in_channels, - out_channels, - 3, - stride=stride, - expansion=expansion)) + MobileBlockV2( + in_channels, + out_channels, + 3, + stride=stride, + expansion=expansion, + )) in_channels = out_channels - out_channels = make_divisible(self.layers[-1] * width_multiplier, - 8, - min_value=1280) + out_channels = make_divisible( + int(self.layers[-1] * width_multiplier), + divisor=8, + min_value=1280, + ) layers.append( nn.Sequential( nn.Conv2d(in_channels, out_channels, 1, bias=False), @@ -120,9 +141,11 @@ def __init__(self, def reset_parameters(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, - mode='fan_out', - nonlinearity='relu') + nn.init.kaiming_normal_( + m.weight, + mode='fan_out', + nonlinearity='relu', + ) if m.bias is not None: nn.init.zeros_(m.bias) if isinstance(m, nn.Linear): diff --git a/torchpack/models/vision/shufflenetv2.py b/torchpack/models/vision/shufflenetv2.py index 8204124..d4be943 100644 --- a/torchpack/models/vision/shufflenetv2.py +++ b/torchpack/models/vision/shufflenetv2.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Dict, List, Tuple, Union +from typing import ClassVar, Dict, List import torch from torch import nn @@ -15,12 +15,15 @@ def channel_shuffle(inputs: torch.Tensor, groups: int) -> torch.Tensor: 
class ShuffleBlockV2(nn.Module): - def __init__(self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - *, - stride: int = 1) -> None: + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + *, + stride: int = 1, + ) -> None: super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -33,13 +36,15 @@ def __init__(self, if stride != 1: self.branch1 = nn.Sequential( - nn.Conv2d(in_channels, - in_channels, - kernel_size, - stride=stride, - padding=kernel_size // 2, - groups=in_channels, - bias=False), + nn.Conv2d( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=in_channels, + bias=False, + ), nn.BatchNorm2d(in_channels), nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), @@ -50,13 +55,15 @@ def __init__(self, nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(True), - nn.Conv2d(out_channels, - out_channels, - kernel_size, - stride=stride, - padding=kernel_size // 2, - groups=out_channels, - bias=False), + nn.Conv2d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=out_channels, + bias=False, + ), nn.BatchNorm2d(out_channels), nn.Conv2d(out_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), @@ -84,22 +91,26 @@ class ShuffleNetV2(nn.Module): 2.0: [24, (244, 4, 2), (488, 8, 2), (976, 4, 2), 2048] } - def __init__(self, - *, - in_channels: int = 3, - num_classes: int = 1000, - width_multiplier: float = 1) -> None: + def __init__( + self, + *, + in_channels: int = 3, + num_classes: int = 1000, + width_multiplier: float = 1, + ) -> None: super().__init__() out_channels = self.layers[width_multiplier][0] layers = nn.ModuleList([ nn.Sequential( - nn.Conv2d(in_channels, - out_channels, - 3, - stride=2, - padding=1, - bias=False), + nn.Conv2d( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + bias=False, + ), nn.BatchNorm2d(out_channels), nn.ReLU(True), ) @@ -130,9 +141,11 @@ def __init__(self, def reset_parameters(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, - mode='fan_out', - nonlinearity='relu') + nn.init.kaiming_normal_( + m.weight, + mode='fan_out', + nonlinearity='relu', + ) if m.bias is not None: nn.init.zeros_(m.bias) if isinstance(m, nn.Linear): diff --git a/torchpack/nn/functional/__init__.py b/torchpack/nn/functional/__init__.py index 623697b..6c32d76 100644 --- a/torchpack/nn/functional/__init__.py +++ b/torchpack/nn/functional/__init__.py @@ -1 +1 @@ -from .index import * \ No newline at end of file +from .index import * diff --git a/torchpack/nn/functional/index.py b/torchpack/nn/functional/index.py index 7b5ca7f..ddd4d6f 100644 --- a/torchpack/nn/functional/index.py +++ b/torchpack/nn/functional/index.py @@ -3,8 +3,11 @@ __all__ = ['batched_index_select'] -def batched_index_select(inputs: torch.Tensor, indices: torch.Tensor, - dim: int) -> torch.Tensor: +def batched_index_select( + inputs: torch.Tensor, + indices: torch.Tensor, + dim: int, +) -> torch.Tensor: vsizes, esizes = [], [] for k, size in enumerate(inputs.shape): if k == 0: diff --git a/torchpack/train/__init__.py b/torchpack/train/__init__.py index 765c240..8c09221 100644 --- a/torchpack/train/__init__.py +++ b/torchpack/train/__init__.py @@ -1,2 +1,2 @@ -from .trainer import * from .exception import * +from .trainer import * diff --git a/torchpack/train/exception.py 
b/torchpack/train/exception.py index b07f358..d7fd99b 100644 --- a/torchpack/train/exception.py +++ b/torchpack/train/exception.py @@ -2,7 +2,6 @@ class StopTraining(Exception): - """ - An exception thrown to stop training. - """ + """An exception thrown to stop training.""" + pass diff --git a/torchpack/train/summary.py b/torchpack/train/summary.py index dce1da6..a369014 100644 --- a/torchpack/train/summary.py +++ b/torchpack/train/summary.py @@ -1,18 +1,24 @@ +import typing from collections import defaultdict, deque -from typing import Any, Deque, Iterable, Optional, Tuple, Union +from typing import Any, Deque, Dict, Iterable, Optional, Tuple, Union import numpy as np import torch -from ..callbacks import SummaryWriter -from ..utils.typing import Trainer +from torchpack.callbacks import SummaryWriter + +if typing.TYPE_CHECKING: + from torchpack.train import Trainer +else: + Trainer = None __all__ = ['Summary'] class Summary: + def __init__(self) -> None: - self.history = defaultdict(deque) + self.history: Dict[str, Deque[Tuple[int, Any]]] = defaultdict(deque) def set_trainer(self, trainer: Trainer) -> None: self.trainer = trainer @@ -24,11 +30,13 @@ def _set_trainer(self, trainer: Trainer) -> None: if isinstance(callback, SummaryWriter): self.writers.append(callback) - def add_scalar(self, - name: str, - scalar: Union[int, float, np.integer, np.floating], - *, - max_to_keep: Optional[int] = None) -> None: + def add_scalar( + self, + name: str, + scalar: Union[int, float, np.integer, np.floating], + *, + max_to_keep: Optional[int] = None, + ) -> None: if isinstance(scalar, np.integer): scalar = int(scalar) if isinstance(scalar, np.floating): @@ -36,8 +44,13 @@ def add_scalar(self, assert isinstance(scalar, (int, float)), type(scalar) self._add_scalar(name, scalar, max_to_keep=max_to_keep) - def _add_scalar(self, name: str, scalar: Union[int, float], *, - max_to_keep: Optional[int]) -> None: + def _add_scalar( + self, + name: str, + scalar: Union[int, float], + *, + max_to_keep: Optional[int], + ) -> None: self.history[name].append((self.trainer.global_step, scalar)) while max_to_keep is not None and \ len(self.history[name]) > max_to_keep: @@ -45,11 +58,13 @@ def _add_scalar(self, name: str, scalar: Union[int, float], *, for writer in self.writers: writer.add_scalar(name, scalar) - def add_image(self, - name: str, - tensor: Union[np.ndarray, torch.Tensor], - *, - max_to_keep: Optional[int] = None) -> None: + def add_image( + self, + name: str, + tensor: Union[np.ndarray, torch.Tensor], + *, + max_to_keep: Optional[int] = None, + ) -> None: if isinstance(tensor, torch.Tensor): tensor = tensor.cpu().numpy() assert isinstance(tensor, np.ndarray), type(tensor) @@ -60,8 +75,13 @@ def add_image(self, assert tensor.ndim == 3 and tensor.shape[0] in [1, 3, 4], tensor.shape self._add_image(name, tensor, max_to_keep=max_to_keep) - def _add_image(self, name: str, tensor: np.ndarray, *, - max_to_keep: Optional[int]) -> None: + def _add_image( + self, + name: str, + tensor: np.ndarray, + *, + max_to_keep: Optional[int], + ) -> None: self.history[name].append((self.trainer.global_step, tensor)) while max_to_keep is not None and \ len(self.history[name]) > max_to_keep: @@ -70,16 +90,13 @@ def _add_image(self, name: str, tensor: np.ndarray, *, writer.add_image(name, tensor) def keys(self) -> Iterable[str]: - for key in self.history.keys(): - yield key + yield from self.history.keys() def values(self) -> Iterable[Deque[Tuple[int, Any]]]: - for value in self.history.values(): - yield value + yield from 
self.history.values() def items(self) -> Iterable[Tuple[str, Deque[Tuple[int, Any]]]]: - for key, value in self.history.items(): - yield key, value + yield from self.history.items() def __contains__(self, key: str) -> bool: return key in self.history diff --git a/torchpack/train/trainer.py b/torchpack/train/trainer.py index cab5039..d5a77c6 100644 --- a/torchpack/train/trainer.py +++ b/torchpack/train/trainer.py @@ -3,28 +3,28 @@ from torch.utils.data import DataLoader, DistributedSampler -from ..callbacks import (Callback, Callbacks, ConsoleWriter, EstimatedTimeLeft, - JSONLWriter, MetaInfoSaver, ProgressBar, - TFEventWriter) -from ..train.exception import StopTraining -from ..train.summary import Summary -from ..utils import humanize -from ..utils.logging import logger +from torchpack.callbacks import (Callback, Callbacks, ConsoleWriter, + EstimatedTimeLeft, JSONLWriter, MetaInfoSaver, + ProgressBar, TFEventWriter) +from torchpack.train.exception import StopTraining +from torchpack.train.summary import Summary +from torchpack.utils import humanize +from torchpack.utils.logging import logger __all__ = ['Trainer'] class Trainer: - """ - Base class for a trainer. - """ - def train_with_defaults(self, - dataflow: DataLoader, - *, - steps_per_epoch: Optional[int] = None, - num_epochs: int = 9999999, - callbacks: Optional[List[Callback]] = None - ) -> None: + """Base class for a trainer.""" + + def train_with_defaults( + self, + dataflow: DataLoader, + *, + steps_per_epoch: Optional[int] = None, + num_epochs: int = 9999999, + callbacks: Optional[List[Callback]] = None, + ) -> None: if callbacks is None: callbacks = [] callbacks += [ @@ -40,12 +40,14 @@ def train_with_defaults(self, steps_per_epoch=steps_per_epoch, callbacks=callbacks) - def train(self, - dataflow: DataLoader, - *, - steps_per_epoch: Optional[int] = None, - num_epochs: int = 9999999, - callbacks: Optional[List[Callback]] = None) -> None: + def train( + self, + dataflow: DataLoader, + *, + steps_per_epoch: Optional[int] = None, + num_epochs: int = 9999999, + callbacks: Optional[List[Callback]] = None, + ) -> None: self.dataflow = dataflow if steps_per_epoch is None: steps_per_epoch = len(self.dataflow) @@ -106,7 +108,7 @@ def train(self, self.num_epochs, humanize.naturaldelta(time.perf_counter() - train_time))) except StopTraining as e: - logger.info('Training was stopped by {}.'.format(str(e))) + logger.info(f'Training was stopped by {str(e)}.') finally: self.after_train() @@ -139,9 +141,7 @@ def run_step(self, feed_dict: Dict[str, Any]) -> Dict[str, Any]: return output_dict def _run_step(self, feed_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Defines what to do in one iteration. 
- """ + """Define what to do in one iteration.""" raise NotImplementedError def after_step(self, output_dict: Dict[str, Any]) -> None: @@ -188,7 +188,7 @@ def state_dict(self) -> Dict[str, Any]: return state_dict def _state_dict(self) -> Dict[str, Any]: - return dict() + return {} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.epoch_num = state_dict.pop('epoch_num') diff --git a/torchpack/utils/config.py b/torchpack/utils/config.py index ffc8125..799ea0d 100644 --- a/torchpack/utils/config.py +++ b/torchpack/utils/config.py @@ -12,6 +12,7 @@ class Config(dict): + def __getattr__(self, key: str) -> Any: if key not in self: raise AttributeError(key) @@ -50,8 +51,8 @@ def update(self, other: Dict) -> None: else: self[key] = value - @multimethod - def update(self, opts: Union[List, Tuple]) -> None: + @multimethod # type: ignore [no-redef] # noqa: F811 + def update(self, opts: Union[List, Tuple]) -> None: # noqa: F811 index = 0 while index < len(opts): opt = opts[index] @@ -67,14 +68,14 @@ def update(self, opts: Union[List, Tuple]) -> None: subkeys = key.split('.') try: value = literal_eval(value) - except: + except Exception: pass for subkey in subkeys[:-1]: current = current.setdefault(subkey, Config()) current[subkeys[-1]] = value def dict(self) -> Dict[str, Any]: - configs = dict() + configs = {} for key, value in self.items(): if isinstance(value, Config): value = value.dict() diff --git a/torchpack/utils/device.py b/torchpack/utils/device.py index c24934a..66095c9 100644 --- a/torchpack/utils/device.py +++ b/torchpack/utils/device.py @@ -8,9 +8,9 @@ def parse_cuda_devices(text: str) -> List[int]: if text == '*': - return [device for device in range(torch.cuda.device_count())] + return list(range(torch.cuda.device_count())) - devices = [] + devices: List[int] = [] for device in text.split(','): device = device.strip().lower() if device == 'cpu': @@ -25,9 +25,11 @@ def parse_cuda_devices(text: str) -> List[int]: return devices -def set_cuda_visible_devices(devices: Union[str, List[int]], - *, - environ: os._Environ = os.environ) -> List[int]: +def set_cuda_visible_devices( + devices: Union[str, List[int]], + *, + environ: os._Environ = os.environ, +) -> List[int]: if isinstance(devices, str): devices = parse_cuda_devices(devices) environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, devices)) diff --git a/torchpack/utils/humanize.py b/torchpack/utils/humanize.py index 98a6144..83c3aaa 100644 --- a/torchpack/utils/humanize.py +++ b/torchpack/utils/humanize.py @@ -22,5 +22,5 @@ def naturaldelta(seconds: float) -> str: continue if values[k] > 1: unit += 's' - texts.append('{:.3g} {}'.format(values[k], unit)) + texts.append(f'{values[k]:.3g} {unit}') return ' '.join(texts) diff --git a/torchpack/utils/imp.py b/torchpack/utils/imp.py index d25d3e2..42ddbb3 100644 --- a/torchpack/utils/imp.py +++ b/torchpack/utils/imp.py @@ -6,15 +6,22 @@ __all__ = ['load_source'] -def load_source(fpath: str, *, name: Optional[str] = None) -> ModuleType: +def load_source( + fpath: str, + *, + name: Optional[str] = None, +) -> Optional[ModuleType]: if name is None: name = os.path.basename(fpath) if name.endswith('.py'): name = name[:-3] name = name.replace('.', '_') - # from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path + # from https://tinyurl.com/4d23vmm9 spec = importlib.util.spec_from_file_location(name, fpath) + if spec is None: + return None module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + if spec.loader is not None: 
+ spec.loader.exec_module(module) # type: ignore [attr-defined] return module diff --git a/torchpack/utils/io.py b/torchpack/utils/io.py index 2c4e756..4bc5947 100644 --- a/torchpack/utils/io.py +++ b/torchpack/utils/io.py @@ -2,20 +2,19 @@ import os import pickle from contextlib import contextmanager -from typing import IO, Any, BinaryIO, Iterator, TextIO, Union +from io import TextIOWrapper +from typing import IO, Any, BinaryIO, Callable, Dict, Iterator, TextIO, Union import numpy as np -import scipy.io -import toml import torch import yaml -from . import fs +from torchpack.utils import fs __all__ = [ 'load', 'save', 'load_json', 'save_json', 'load_jsonl', 'save_jsonl', 'load_mat', 'save_mat', 'load_npy', 'save_npy', 'load_npz', 'save_npz', - 'load_pt', 'save_pt', 'load_taml', 'save_taml', 'load_yaml', 'save_yaml' + 'load_pt', 'save_pt', 'load_toml', 'save_toml', 'load_yaml', 'save_yaml' ] @@ -29,6 +28,7 @@ def file_descriptor(f: Union[str, IO], mode: str = 'r') -> Iterator[IO]: yield f finally: if opened: + assert isinstance(f, TextIOWrapper), type(f) f.close() @@ -53,10 +53,12 @@ def save_jsonl(f: Union[str, TextIO], obj: Any, **kwargs) -> None: def load_mat(f: Union[str, BinaryIO], **kwargs) -> Any: + import scipy.io return scipy.io.loadmat(f, **kwargs) def save_mat(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None: + import scipy.io scipy.io.savemat(f, obj, **kwargs) @@ -100,18 +102,20 @@ def save_pt(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None: torch.save(obj, f, **kwargs) -def load_toml(f: Union[str, TextIO], obj: Any, **kwargs) -> Any: - return toml.load(f, obj, **kwargs) +def load_toml(f: Union[str, TextIO], obj: Any) -> Any: + import toml + return toml.load(f, obj) -def save_toml(f: Union[str, TextIO], obj: Any, **kwargs) -> None: +def save_toml(f: Union[str, TextIO], obj: Any) -> None: + import toml with file_descriptor(f, mode='w') as fd: - toml.dump(obj, fd, **kwargs) + toml.dump(obj, fd) -def load_yaml(f: Union[str, TextIO], **kwargs) -> Any: +def load_yaml(f: Union[str, TextIO]) -> Any: with file_descriptor(f, mode='r') as fd: - return yaml.safe_load(fd, **kwargs) + return yaml.safe_load(fd) def save_yaml(f: Union[str, TextIO], obj: Any, **kwargs) -> None: @@ -120,7 +124,7 @@ def save_yaml(f: Union[str, TextIO], obj: Any, **kwargs) -> None: # yapf: disable -__io_registry = { +__io_registry: Dict[str, Dict[str, Callable]] = { '.json': {'load': load_json, 'save': save_json}, '.jsonl': {'load': load_jsonl, 'save': save_jsonl}, '.mat': {'load': load_mat, 'save': save_mat}, diff --git a/torchpack/utils/logging.py b/torchpack/utils/logging.py index 30d84b1..f9fd2cb 100644 --- a/torchpack/utils/logging.py +++ b/torchpack/utils/logging.py @@ -1,6 +1,10 @@ import sys +import typing -from .typing import Logger +if typing.TYPE_CHECKING: + from loguru import Logger +else: + Logger = None __all__ = ['logger'] @@ -8,10 +12,12 @@ def __get_logger() -> Logger: from loguru import logger logger.remove() - logger.add(sys.stdout, - level='DEBUG', - format=('[{time:YYYY-MM-DD HH:mm:ss.SSS}] ' - '{message}')) + logger.add( + sys.stdout, + level='DEBUG', + format=('[{time:YYYY-MM-DD HH:mm:ss.SSS}] ' + '{message}'), + ) return logger diff --git a/torchpack/utils/matching.py b/torchpack/utils/matching.py index a40db2d..f06fc41 100644 --- a/torchpack/utils/matching.py +++ b/torchpack/utils/matching.py @@ -5,6 +5,7 @@ class NameMatcher: + def __init__(self, patterns: Optional[Union[str, List[str]]]): if patterns is None: patterns = [] diff --git a/torchpack/utils/network.py 
b/torchpack/utils/network.py new file mode 100644 index 0000000..3e2aeb6 --- /dev/null +++ b/torchpack/utils/network.py @@ -0,0 +1,10 @@ +import socket + +__all__ = ['get_free_tcp_port'] + + +def get_free_tcp_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp: + tcp.bind(('0.0.0.0', 0)) + port = tcp.getsockname()[1] + return port diff --git a/torchpack/utils/typing.py b/torchpack/utils/typing.py index daad49a..947ad43 100644 --- a/torchpack/utils/typing.py +++ b/torchpack/utils/typing.py @@ -1,17 +1,7 @@ -import typing +from torch.optim.lr_scheduler import _LRScheduler as Scheduler +from torch.optim.optimizer import Optimizer -__all__ = ['Logger', 'Dataset', 'Optimizer', 'Scheduler', 'Trainer'] +from torchpack.datasets.dataset import Dataset +from torchpack.train import Trainer -Logger = None -Dataset = None -Optimizer = None -Scheduler = None -Trainer = None - -# https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports -if typing.TYPE_CHECKING: - from loguru import Logger - from torch.optim.lr_scheduler import _LRScheduler as Scheduler - from torch.optim.optimizer import Optimizer - from torchpack.datasets.dataset import Dataset - from torchpack.train import Trainer +__all__ = ['Dataset', 'Optimizer', 'Scheduler', 'Trainer']
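
The `typing.TYPE_CHECKING` blocks introduced above (in `torchpack/train/summary.py` and `torchpack/utils/logging.py`) replace the old runtime-`None` aliases that lived in `torchpack/utils/typing.py`. A minimal sketch of the pattern, assuming only that importing the real class at runtime would create an import cycle:

```python
import typing

# mypy resolves the real Trainer class here; at runtime the branch is
# skipped and the name is bound to a harmless None placeholder instead,
# so importing this module never triggers the circular import.
if typing.TYPE_CHECKING:
    from torchpack.train import Trainer
else:
    Trainer = None


class Summary:

    def set_trainer(self, trainer: Trainer) -> None:
        # The annotation evaluates to None at runtime, which Python
        # accepts; mypy still checks it against the real Trainer.
        self.trainer = trainer
```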
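`make_divisible` is now called with an explicit `int(...)` cast and a keyword `divisor` everywhere. A worked example of the rounding, following the visible `x = max(min_value, int(v + divisor / 2) // divisor * divisor)` line (the tail of the function sits outside the hunk above, so treat the exact values as a sketch):

```python
from torchpack.models.utils import make_divisible

# Channel counts scaled by width_multiplier = 0.75 and rounded back to
# a multiple of 8, as the MobileNetV1/V2 builders above do.
assert make_divisible(int(32 * 0.75), divisor=8) == 24
assert make_divisible(int(64 * 0.75), divisor=8) == 48

# The last MobileNetV2 stage never drops below 1280 channels.
assert make_divisible(int(1280 * 0.75), divisor=8, min_value=1280) == 1280
```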
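`torchpack/environ/rundir.py` now initializes `_run_dir = None` at module level and asserts inside `get_run_dir`, so calling the accessor before a run directory exists fails with a clear `AssertionError` rather than a `NameError`. The intended call order, using only the names exported by `torchpack.environ`:

```python
from torchpack.environ import get_run_dir, set_run_dir

set_run_dir('runs/example')  # registers the per-run logging sink
run_dir = get_run_dir()      # would raise AssertionError if unset
```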
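`get_free_tcp_port` moved out of the launcher into the new `torchpack/utils/network.py`, so `torchpack.distributed.init` can reuse it for its single-process fallback. A sketch of what that fallback computes:

```python
from torchpack.utils.network import get_free_tcp_port

# Binding to port 0 asks the OS for any free port; init() uses it to
# bootstrap a localhost process group when MASTER_HOST is not set.
master_host = f'tcp://localhost:{get_free_tcp_port()}'
```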
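`Trainer._run_step` remains the single hook a subclass must implement. A forward-and-backward sketch; the `model`, `optimizer`, and `criterion` attributes and the `'image'`/`'class'` feed-dict keys are hypothetical stand-ins, not part of the `Trainer` API:

```python
from typing import Any, Dict

import torch
from torch import nn

from torchpack.train import Trainer


class ClassificationTrainer(Trainer):

    def __init__(self, model: nn.Module,
                 optimizer: torch.optim.Optimizer) -> None:
        self.model = model          # hypothetical: the base class does
        self.optimizer = optimizer  # not prescribe these attributes
        self.criterion = nn.CrossEntropyLoss()

    def _run_step(self, feed_dict: Dict[str, Any]) -> Dict[str, Any]:
        # 'image' / 'class' mirror the dict-style vision datasets above,
        # but the exact keys are an assumption here.
        outputs = self.model(feed_dict['image'])
        loss = self.criterion(outputs, feed_dict['class'])
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {'outputs': outputs, 'loss': loss.item()}
```

A subclass like this plugs directly into the loop shown above, e.g. `ClassificationTrainer(model, optimizer).train_with_defaults(dataflow, num_epochs=100)`.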