diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 495f9d96..83dcde4b 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -29,9 +29,8 @@ jobs: run: | echo "pypi_version=v$(python extract_version.py)" >> $GITHUB_OUTPUT echo "full_version=v$(python extract_version.py).$(date +'%Y%m%d')" >> $GITHUB_OUTPUT - echo "git_result=$(echo $(git tag -l "${{ steps.version.outputs.pypi_version }}*" | head -n 1))" >> $GITHUB_OUTPUT + echo "git_result=$(echo $(git tag -l "v$(python extract_version.py)*" | head -n 1))" >> $GITHUB_OUTPUT - run: | - echo $(git tag -l ${{ steps.version.outputs.full_version }}) echo $(git tag -l "${{ steps.version.outputs.pypi_version }}*" | head -n 1) echo ${{ steps.version.outputs.git_result == '' }} echo ${{ steps.version.outputs.git_result }} diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml new file mode 100644 index 00000000..8fb78f0e --- /dev/null +++ b/.github/workflows/python-test.yml @@ -0,0 +1,35 @@ +name: Python test + +on: [ push ] + +jobs: + build: + + runs-on: ubuntu-20.04 + strategy: + matrix: + python-version: [ "3.6.9", "3.7", "3.8", "3.9", "3.10", "3.11" ] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 src/lumo --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide + # flake8 src/lumo --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + pytest + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 \ No newline at end of file diff --git a/LICENSE b/LICENSE index 50fff668..753842b6 100644 --- a/LICENSE +++ b/LICENSE @@ -1,17 +1,201 @@ - Copyright (C) 2020 Shandong University + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ - This program is licensed under the GNU General Public License 3.0 - (https://www.gnu.org/licenses/gpl-3.0.html). - Any derivative work obtained under this license must be licensed - under the GNU General Public License as published by the Free - Software Foundation, either Version 3 of the License, or (at your option) - any later version, if this derivative work is distributed to a third party. + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - The copyright for the program is owned by Shandong University. - For commercial projects that require the ability to distribute - the code of this program as part of a program that cannot be - distributed under the GNU General Public License, please contact + 1. Definitions. - sailist@outlook.com + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. - to purchase a commercial license. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.ch.md b/README.ch.md new file mode 100644 index 00000000..afbb7798 --- /dev/null +++ b/README.ch.md @@ -0,0 +1,185 @@ +# lumo + +[![PyPI version](https://badge.fury.io/py/lumo.svg)](https://badge.fury.io/py/lumo) +![Python-Test](https://github.com/pytorch-lumo/lumo/actions/workflows/python-test.yml/badge.svg) +[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/pytorch-lumo/lumo/blob/master/LICENSE) + + +`lumo`:轻量、可扩展、功能解耦合的 Pytorch 实验框架。 + +lumo 的设计理念: + +- 模块解耦合:所有模块可以单独作为您现在使用的框架中的一个插件使用(而不像其他框架几乎耦合在一起) +- 恰到好处的抽象:和模型相关的细节完全由使用者掌控,lumo 只封装了外部通用逻辑(而不像其他一些框架会代理模型初始化或损失迭代) +- 覆盖整个生命周期:数据集构建、模型初始化、随机种子、训练/测试...,lumo 为所有步骤提供了功能包或流程简化 +- 极强的可扩展性:从单文件到包含多个领域多个方法的项目,lumo 都可以提供舒适的使用体验。已在两个领域有复现项目的最佳实践示例(见[Related Work](#Related Work))。 + +# 如何使用 + +## 安装 + +从 pypi 或 github 主页安装最新的稳定版本: + +```bash +pip install -U lumo +pip install git+https://github.com/pytorch-lumo/lumo +``` + + +## 快速开始 + +本节包含 lumo 最常用的几个子功能,帮助使用者快速利用这些功能减少已有项目中的冗余代码,这些功能包括: + +- 一些常用功能的更优替代,如 [Params](#参数控制)(arguments 的平替),[Logger](#变量&日志记录)(logging 的平替) +- 一些训练过程中部份流程的优化,如 [Experiment](#路径管理&版本控制)(提供无重复的实验路径管理、基于 git 的版本控制),[DatasetBuilder](#数据集构建)(更快构建数据集), + +### 参数控制 + +`argparse` 的更优替代。`Params` 底层依托于 [omegaconf](https://github.com/omry/omegaconf) 和 [fire](https://github.com/google/python-fire) +,只需要简单的配置,就可以从文件、命令行中读取参数。 + +直接基于 
Params 类定义参数: + +```python +# python main.py --epoch=30 --dataset=cifar100 +from lumo import Params + +params = Params() +params.epoch = 20 +# 集成优化器参数,自带补全提示 +params.optim = params.OPTIM.create_optim('Adam', lr=0.0001, weight_decay=4e-5) +# 数据集只能从 cifar10/cifar100 中选择,且默认为 cifar10,其他的选择会报错 +params.dataset = params.choice('cifar10', 'cifar100') + +# 从命令行参数中更新 +params.from_args() +print(params.epoch) # -> 30 +print(params.dataset) # -> cifar100 + +# 保存到文件 +params.to_json('./config.json') +params.to_yaml('./config.yaml') +# 从文件中更新 +params.from_json('./config.json') +params.from_yaml('./config.yaml') +``` + +也可以通过继承、多重继承来嵌套,组合参数。即使在命令行中输入了不存在的参数,Params 也会正常读取。 + +### 变量&日志记录 + +`logging` 的更优替代。通过 Meter、Record 和 Logger,可以实现变量的记录和格式化输出。其中: + +- Meter 记录单次的值 +- Record 以一定规则归约 Meter 实例(如 mean、sum 等) +- Logger 用于代替 logging,除常用的 info、warn 等方法外,还提供了 inline 方法,可以在屏幕上单行更新(实际中,屏幕打印时间远小于训练时间,因此单行更新带来的时间开销可以忽略不计)。 + +```python +import random +import time + +from lumo import Record, Meter, Logger + +log = Logger() + +record = Record() +for idx in range(256): + meter = Meter() + meter.last.i = idx + meter.sum.acc = idx + meter.mean.loss = random.random() + + record.record(meter) + log.inline(record) # 单行更新 + time.sleep(0.5) + if idx % 50 == 0: + log.newline() + record.clear() + +log.info(record) +``` + +### 路径管理&版本控制 + +`Experiment` 主要提供路径管理,可以为每一次实验根据实验名、日期、次数等自动提供不一样的保存路径。此外,Experiment 还可以通过 hook +提供如代码版本管理、元数据记录等功能。在实验中,可以使用其子类 `SimpleExperiment` 实现大部分需求。 + +```python +from lumo import SimpleExperiment +from lumo import Params + +pm = Params() +pm.module = 'example' +pm.from_args() + +# 注册该次实验,实验名为 `pm.module` +exp = SimpleExperiment(pm.module) +# 实验开始,该方法会调用已注册的 ExpHooks,完成代码版本控制等功能。 +exp.start() + +# 小数据通过 `.test_file()` 获得路径 +fn = exp.test_file('params.json') +pm.to_json(fn) + +# 大文件通过 `.blob_file()` 获得路径(这是约定,而不是强制,大文件也可以保存到 `.test_file()` 中) +fn = exp.blob_file('checkpoint.pt') +with open(fn, 'w') as w: + w.write('write big data in blob file') + +print(exp.test_root) 
+print(exp.get_prop('git')) # see git commit history +exp.end() +``` + +### 数据集构建 + +![DatasetBuilder](./images/DatasetBuilder.png) + +`DatasetBuilder` 是采用有向无环图思路设计的数据集构建类,该类提供了一个恰当的抽象逻辑,避免了在一个实验里定义多个重复 Datasets 类。 + +`DatasetBuilder` 将数据集的构件划分为输入-输出两阶段,同时提供 `.chain()`(序列格式)和`.zip()`(字典格式) 两种输出方式。 + +```python +from lumo import DatasetBuilder +from torchvision.transforms import transforms +import torch + +# Create a mnist-like dummy dataset +db = ( + DatasetBuilder() + .add_input("xs", torch.rand(500, 28, 28)) + .add_input("ys", torch.randint(0, 10, (500,))) + .add_idx('id') + .add_output("xs", "xs1", transforms.RandomHorizontalFlip()) + .add_output("xs", "xs2", ) + .add_output("ys", "ys") +) +# Watch dataset structure +print(db) +# Builder(flow={'::idx::': ['id'], 'xs': ['xs1', 'xs2'], 'ys': ['ys']}, sized=True, size=500, iterable=True) + +print(db[0]) +# dict_keys(['id', 'xs1', 'xs2', 'ys']) +``` + +# 更多教程 + +# Related Work + +- [image-classification](https://github.com/pytorch-lumo/image-classification): supervised/semi-supervised/self-supervised/noisy label learning on image-classification + field. (supported datasets: CIFAR10/CIFAR100/STL-10/SVHN/ImageNet/tinyimagenet) +- [emotion-recognition-in-conversation](https://github.com/pytorch-lumo/emotion-recognition-in-conversation): Multimodal emotion recognition in conversation. (supported datasets: IEMOCAP/MELD/MOSEI) + + +# Acknowledge + + 一个人维护一个库四年,背后的动力是我持续不断的使用,感谢 lumo 陪我见证我的学术生涯。lumo 确实不一定适合所有人的习惯,但一定最适合我自己。lumo 取自 lumos,这是哈利波特里魔法杖发光的咒语。torch 是火炬,ignite 是点燃,所以 lumo 也向往着发光发热,希望 lumo 给大家带来美好的使用体验。 + +# License + +Distributed under the Apache License 2.0. See [LICENSE](./LICENSE) for more information. 
+ +# Contact + + - [sailist@outlook.com](mailto:sailist@outlook.com) + diff --git a/README.en.md b/README.en.md deleted file mode 100644 index 1d203a90..00000000 --- a/README.en.md +++ /dev/null @@ -1,161 +0,0 @@ -# lumo - -`lumo` is a light-weight library to help construct your experiment code, record your experiment results, especially in the field of deep learning. - - -## Features - -`lumo` is designed for reducing difficulty of the frequent code modification in experiments and simplify the redundant code. - -At present, `lumo` has these features: - - - Simplest code for **Hyperparameter Configuration**、**Dataset Building**、**Module Checkpoint**、**Meter and Log**. - - Include Git support and random seed management. You can **reset** and **archive** and **reimplement your experiments** by using simple console command. - - Include a **deep learning experiment code templete**. You can add any experiments with linearly increasing code complexity by using it. - - The framework follows the design paradigm of **convention over configuration**, the more you follow the convention, the more the framework will do for you. - -> Better use Pycharm. - -See [document](https://sailist.github.io/lumo/) for details. - - - -## Install -```bash -pip install lumo -``` - -or - -```bash -git clone https://github.com/sailist/lumo - -python setup.py install -``` - -### test - -``` -python -m pytest # or python3 -m pytest -``` - -> Only a part of code have unit test. - - -## Requirements - - - install lumo will automatically install three light other libraries: [fire](https://github.com/google/python-fire), [psutil](https://github.com/giampaolo/psutil), [joblib](https://github.com/joblib/joblib). - - lumo has mandatory dependencies on `pytorch`, `pandas` and `numpy`, you should manully install these before using lumo since they are usually common-used. 
- - lumo has an optional dependency on `GitPython` as a plugin to execute git command, you can run `pip install GitPython` to install it. - -```shell -pip install pandas numpy GitPython -``` -and then see [pytorch](https://pytorch.org/) to install torch. - - - -## Introduction - -Unlike other pytorch tools, `lumo` mainly designed for research, there are two core idea of it: - -1. Reduce repetition of your code. -2. Make all operations **recordable**, **resumable**, **analyzable**. - - -Your can click [Tutorial](https://sailist.github.io/lumo/tutorial/) to learn the basic use of this framework. After that, you can view [Cookbook](https://sailist.github.io/lumo/cookbook/) to see some details of this library. - -A suggested learning order may be: - - - Learn highly frequency used module: [Define hyperparameter(Params)](https://sailist.github.io/lumo/params)、[Record variable(Meter)](https://sailist.github.io/lumo/meter)、[Log(Logger)](/lumo/logger)、[Reshape your dataloader(DataBundler)](https://sailist.github.io/lumo/bundler) and their aggregation [Trainer](https://sailist.github.io/lumo/trainer). - - Learn how to manage/analyse your experiment by [Config](https://sailist.github.io/lumo/exp) and [Experiment](https://sailist.github.io/lumo/exp) - - Learn how to simple manage random seed by [RndManager](https://sailist.github.io/lumo/rnd) and to create your dataset elegantly by [DatasetBuilder](https://sailist.github.io/lumo/builder) - -After learning above contents, you can view [Cookbook](https://sailist.github.io/lumo/cookbook/) to learn the use of [tempelet code](https://sailist.github.io/lumo/structure) and other [details](https://sailist.github.io/lumo/details). - -You can also view another repository [lumo-implement](https://github.com/lumo/lumo-implement) to see a bigger example, it will continuously reimplement papers I interested by using the templete provided in `lumo`. 
- -## Examples - -Before start, maybe you'd like to see some simple examples to learn what can `lumo` do. - -### Define hyperparameters -By use `lumo.frame.Params`, you can define hyperparameters simply. See [Params](https://sailist.github.io/lumo/params) for details. -```python -from lumo import Params -params = Params() -params.batch_size = 128 -params.from_args() # from command args - ->>> python ap.py --optim.lr=0.001 --epoch=400 --dataset=cifar10 --k=12 -``` -### Record variable - -By using `lumo.frame.Meter`, you can record variable and update its average value with as little code as possible. See [Meter](https://sailist.github.io/lumo/meter) for details. - -```python -from lumo import Meter,AvgMeter - -am = AvgMeter() # use for record average -for j in range(500): - meter = Meter() - meter.percent(meter.c_) # when print, format 'c' as a percentage - meter.a = 1 - meter.b = "2" - meter.c = torch.rand(1)[0] - - meter.loss = loss_fn(...) - meter.rand = torch.rand(2) - meter.d = [4] # you can record any type of variable - meter.e = {5: "6"} - - am.update(meter) # Update current value in meter. Average value will be calculated automatic by declaration and the type of the variable. - print(am) -``` - - -## Contribute - -`lumo` will be better in the future, but there are still some lack exists currently, including: - - - **Lack of more detail guide** because of the lacking of developer's energy and time. - - **Lack more tests**. unit test only covers a part of the code. I hope I fixed all bugs during my using of it, but there is no guarantee of it. The compatibility is also unguaranteed. So, welcome to [issus](https://github.com/sailist/lumo/issues) it if you find it. - - **Lack of development experience**. So the version number may be confused. - -Thanks for all contribution. - - - -For file read/write and get/set, I designed - -- [Params], to make runtime config get/set/load/dump easily, -- [globs], a global/local/runtime environment variables manager. 
-- [Saver], to help you save/load/manage your checkpoints/models in one class. - -For data processing, I designed - -- [Builder], to hold nearly all dataset formats and special operations by one class, - -For managing experiments, I designed - -- [Experiment], which can - - make you build a suitable directory and file path in one place, - - make you record lightweight data, and - - help you make snapshot for your project code (based on git), which can make each result recoverable and - reproducible -- [random manager], a cross-lib(random/numpy/pytorch) random seed manager - -For log and meter variables produced during experiment, I designed - -- [Meter] to meter every thing in appropriate format, and -- [Logger] to log every thing in appropriate format. - -Finally, I designed [Trainer] to bundle all module above for deep learning experiment. - - -As you can see, These modules covered most demandings on deeplearning - -You can find what you want and click the link to quickly learn HOW TO USE it! All module is designed easy to use, it's -my principles. 
- - diff --git a/README.md b/README.md index 44b1d5ef..13fa0003 100644 --- a/README.md +++ b/README.md @@ -1,180 +1,165 @@ -# lumo - -`lumo`:轻量、可扩展、功能解耦合的 Pytorch 实验框架。 - -lumo 的设计理念: - -- 模块解耦合:所有模块可以单独作为您现在使用的框架中的一个插件使用(而不像其他框架几乎耦合在一起) -- 恰到好处的抽象:和模型相关的细节完全由使用者掌控,lumo 只封装了外部通用逻辑(而不像其他一些框架会代理模型初始化或损失迭代) -- 覆盖整个生命周期:数据集构建、模型初始化、随机种子、训练/测试...,lumo 为所有步骤提供了功能包或流程简化 -- 极强的可扩展性:从单文件到包含多个领域多个方法的项目,lumo 都可以提供舒适的使用体验。已在两个领域有复现项目的最佳实践示例(见[Related Work](#Related Work))。 - -# 如何使用 - -## 安装 - -从 pypi 或 github 主页安装最新的稳定版本: - -```bash -pip install -U lumo -pip install git+https://github.com/pytorch-lumo/lumo -``` - - -## 快速开始 - -本节包含 lumo 最常用的几个子功能,帮助使用者快速利用这些功能减少已有项目中的冗余代码,这些功能包括: - -- 一些常用功能的更优替代,如 [Params](#参数控制)(arguments 的平替),[Logger](#变量&日志记录)(logging 的平替) -- 一些训练过程中部份流程的优化,如 [Experiment](#路径管理&版本控制)(提供无重复的实验路径管理、基于 git 的版本控制),[DatasetBuilder](#数据集构建)(更快构建数据集), - -### 参数控制 - -`argparse` 的更优替代。`Params` 底层依托于 [omegaconf](https://github.com/omry/omegaconf) 和 [fire](https://github.com/google/python-fire) -,只需要简单的配置,就可以从文件、命令行中读取参数。 - -直接基于 Params 类定义参数: - -```python -# python main.py --epoch=30 --dataset=100 -from lumo import Params - -params = Params() -params.epoch = 20 -# 集成优化器参数,自带补全提示 -params.optim = params.OPTIM.create_optim('Adam', lr=0.0001, weight_decay=4e-5) -# 数据集只能从 cifar10/cifar100 中选择,且默认为 cifar10,其他的选择会报错 -params.dataset = params.choice('cifar10', 'cifar100') - -# 从命令行参数中更新 -params.from_args() -print(params.epoch) # -> 30 -print(params.dataset) # -> cifar100 - -# 保存到文件 -params.to_json('./config.json') -params.to_yaml('./config.yaml') -# 从文件中更新 -params.from_json('./config.json') -params.from_yaml('./config.yaml') -``` - -也可以通过继承、多重继承来嵌套,组合参数。即使在命令行中输入了不存在的参数,Params 也会正常读取。 - -### 变量&日志记录 - -`logging` 的更优替代。通过 Meter、Record 和 Logger,可以实现变量的记录和格式化输出。其中: - -- Meter 记录单次的值 -- Record 以一定规则归约 Meter 实例(如 mean、sum 等) -- Logger 用于代替 logging,除常用的 info、warn 等方法外,还提供了 inline 方法,可以在屏幕能单行更新(实际中,屏幕打印时间远小于训练时间,因此单行更新带来的时间开销可以忽略不计)。 - -```python -import random -import time - 
-from lumo import Record, Meter, Logger - -log = Logger() - -record = Record() -for idx in range(256): - meter = Meter() - meter.last.i = idx - meter.sum.acc = idx - meter.mean.loss = random.random() - - record.record(meter) - log.inline(record) # 单行更新 - time.sleep(0.5) - if idx % 50 == 0: - log.newline() - record.clear() - -log.info(record) -``` - -### 路径管理&版本控制 - -`Experiment` 主要提供路径管理,可以为每一次试验根据实验名、日期、次数等自动提供不一样的保存路径。此外,Experiment 还可以通过 hook -提供如代码版本管理、元数据记录等功能。在实验中,可以使用其子类 `SimpleExperiment` 实现大部分需求。 - -```python -from lumo import SimpleExperiment -from lumo import Params - -pm = Params() -pm.module = 'example' -pm.from_args() - -# 注册该次实验,实验名为 `pm.module` -exp = SimpleExperiment(pm.module) -# 实验开始,该方法会调用已注册的 ExpHooks,完成代码版本控制等功能。 -exp.start() - -# 小数据通过 `.test_file()` 获得路径 -fn = exp.test_file('params.json') -pm.to_json(fn) - -# 大文件通过 `.blob_file()` 获得路径(这是约定,而不是强制,大文件也可以保存到 `.test_file()` 中) -fn = exp.blob_file('checkpoint.pt') -with open(fn, 'w') as w: - w.write('write big data in blob file') - -print(exp.test_root) -print(exp.get_prop('git')) # see git commit history -exp.end() -``` - -### 数据集构建 - -![DatasetBuilder](./images/DatasetBuilder.png) - -`DatasetBuilder` 是采用有向无环图思路设计的数据集构建类,该类提供了一个恰当的抽象逻辑,避免了在一个实验里定义多个重复 Datasets 类。 - -`DatasetBuilder `将数据集的构件划分为输入-输出两阶段,同时提供 `.chain()`(序列格式)和`.zip()`(字典格式) 两种输出方式。 - -```python -from lumo import DatasetBuilder -from torchvision.transforms import transforms -import torch - -# Create a mnist-like dummy dataset -db = ( - DatasetBuilder() - .add_input("xs", torch.rand(500, 28, 28)) - .add_input("ys", torch.randint(0, 10, (500,))) - .add_idx('id') - .add_output("xs", "xs1", transforms.RandomHorizontalFlip()) - .add_output("xs", "xs2", ) - .add_output("ys", "ys") -) -# Watch dataset structure -print(db) -# Builder(flow={'::idx::': ['id'], 'xs': ['xs1', 'xs2'], 'ys': ['ys']}, sized=True, size=500, iterable=True) - -print(db[0]) -# dict_keys(['id', 'xs1', 'xs2', 'ys']) -``` - -# 更多教程 - -# Related Work - -- 
[image-classification](https://github.com/pytorch-lumo/image-classification): supervised/semi-supervised/self-supervised/noisy label learning on image-classfication - field. (suporrted datasets: CIFAR10/CIFAR100/STL-10/SVHN/ImageNet/tinyimagenet) -- [emotion-recognition-in-conversation](https://github.com/pytorch-lumo/emotion-recognition-in-conversation):Multimodel emotional recognition on conversation. (suporrted datasets: IEMOCAP/MELD/MOSEI) - - -# Acknowledge - - 一个人维护一个库四年,背后的动力是我持续不断的使用,感谢 lumo 陪我见证我的学术生涯。lumo 确实不一定适合所有人的习惯,但一定最适合我自己。lumo 取自 lumos,这是哈利波特里魔法杖发光的咒语。torch 是火炬,ignite 是点燃,所以 lumo 也向往着发光发热,希望 lumo 给大家带来美好的使用体验。 - -# License - -Distributed under the GNU General Public License 3.0. See [LICENSE](./LICENSE) for more information. - -# Contact - - - [sailist@outlook.com](mailto:sailist@outlook.com) - +# lumo + +[![PyPI version](https://badge.fury.io/py/lumo.svg)](https://badge.fury.io/py/lumo) +![Python-Test](https://github.com/pytorch-lumo/lumo/actions/workflows/python-test.yml/badge.svg) +[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/pytorch-lumo/lumo/blob/master/LICENSE) + +`lumo` is a light-weight library to help construct your experiment code and record your experiment results, especially in the field of deep learning. + + +## Features + +`lumo` is designed to reduce the difficulty of frequent code modification in experiments and to simplify redundant code. + +At present, `lumo` has these features: + + - Simplest code for **Hyperparameter Configuration**、**Dataset Building**、**Module Checkpoint**、**Meter and Log**. + - Includes Git support and random seed management. You can **reset**, **archive** and **reimplement your experiments** by using simple console commands. + - Includes a **deep learning experiment code template**. You can add any experiments with linearly increasing code complexity by using it. 
+ - The framework follows the design paradigm of **convention over configuration**: the more you follow the convention, the more the framework will do for you. + +> PyCharm is recommended. + +See [document](https://sailist.github.io/lumo/) for details. + + + +## Install +```bash +pip install lumo +``` + +or + +```bash +git clone https://github.com/sailist/lumo + +python setup.py install +``` + +### test + +``` +python -m pytest # or python3 -m pytest +``` + +> Only part of the code has unit tests. + + +## Requirements + + - Installing lumo will automatically install three other lightweight libraries: [fire](https://github.com/google/python-fire), [psutil](https://github.com/giampaolo/psutil), [joblib](https://github.com/joblib/joblib). + - lumo has mandatory dependencies on `pytorch`, `pandas` and `numpy`; you should manually install these before using lumo since they are commonly used. + - lumo has an optional dependency on `GitPython` as a plugin to execute git commands; you can run `pip install GitPython` to install it. + +```shell +pip install pandas numpy GitPython +``` +and then see [pytorch](https://pytorch.org/) to install torch. + + + +## Introduction + +Unlike other pytorch tools, `lumo` is mainly designed for research, and it has two core ideas: + +1. Reduce repetition of your code. +2. Make all operations **recordable**, **resumable**, **analyzable**. + + +You can click [Tutorial](https://sailist.github.io/lumo/tutorial/) to learn the basic use of this framework. After that, you can view [Cookbook](https://sailist.github.io/lumo/cookbook/) to see some details of this library. + +A suggested learning order may be: + + - Learn the most frequently used modules: [Define hyperparameter(Params)](https://sailist.github.io/lumo/params)、[Record variable(Meter)](https://sailist.github.io/lumo/meter)、[Log(Logger)](https://sailist.github.io/lumo/logger)、[Reshape your dataloader(DataBundler)](https://sailist.github.io/lumo/bundler) and their aggregation [Trainer](https://sailist.github.io/lumo/trainer). 
+ - Learn how to manage/analyse your experiment by [Config](https://sailist.github.io/lumo/exp) and [Experiment](https://sailist.github.io/lumo/exp) + - Learn how to simply manage random seeds by [RndManager](https://sailist.github.io/lumo/rnd) and to create your dataset elegantly by [DatasetBuilder](https://sailist.github.io/lumo/builder) + +After learning the above contents, you can view [Cookbook](https://sailist.github.io/lumo/cookbook/) to learn the use of [template code](https://sailist.github.io/lumo/structure) and other [details](https://sailist.github.io/lumo/details). + +You can also view another repository [lumo-implement](https://github.com/lumo/lumo-implement) to see a bigger example, it will continuously reimplement papers I am interested in by using the template provided in `lumo`. + +## Examples + +Before starting, maybe you'd like to see some simple examples to learn what `lumo` can do. + +### Define hyperparameters +By using `lumo.frame.Params`, you can define hyperparameters simply. See [Params](https://sailist.github.io/lumo/params) for details. +```python +from lumo import Params +params = Params() +params.batch_size = 128 +params.from_args() # from command args + +>>> python ap.py --optim.lr=0.001 --epoch=400 --dataset=cifar10 --k=12 +``` +### Record variable + +By using `lumo.frame.Meter`, you can record variables and update their average values with as little code as possible. See [Meter](https://sailist.github.io/lumo/meter) for details. + +```python +from lumo import Meter,AvgMeter + +am = AvgMeter() # used to record averages +for j in range(500): + meter = Meter() + meter.percent(meter.c_) # when print, format 'c' as a percentage + meter.a = 1 + meter.b = "2" + meter.c = torch.rand(1)[0] + + meter.loss = loss_fn(...) + meter.rand = torch.rand(2) + meter.d = [4] # you can record any type of variable + meter.e = {5: "6"} + + am.update(meter) # Update current value in meter. Average value will be calculated automatically by declaration and the type of the variable.
+ print(am) +``` + + +## Contribute + +`lumo` will be better in the future, but it still has some shortcomings, including: + + - **Lack of a more detailed guide** because of the developer's limited energy and time. + - **Lack of tests**. Unit tests only cover a part of the code. I hope I fixed all bugs while using it, but there is no guarantee. Compatibility is not guaranteed either. So, please open an [issue](https://github.com/sailist/lumo/issues) if you find a bug. + - **Lack of development experience**. So the version numbers may be confusing. + +Thanks for all contributions. + + + +For file read/write and get/set, I designed + +- [Params], to make runtime config get/set/load/dump easily, +- [globs], a global/local/runtime environment variables manager. +- [Saver], to help you save/load/manage your checkpoints/models in one class. + +For data processing, I designed + +- [Builder], to hold nearly all dataset formats and special operations by one class, + +For managing experiments, I designed + +- [Experiment], which can + - make you build a suitable directory and file path in one place, + - make you record lightweight data, and + - help you make a snapshot of your project code (based on git), which can make each result recoverable and + reproducible +- [random manager], a cross-lib(random/numpy/pytorch) random seed manager + +For log and meter variables produced during experiment, I designed + +- [Meter] to meter everything in appropriate format, and +- [Logger] to log everything in appropriate format. + +Finally, I designed [Trainer] to bundle all the modules above for deep learning experiments. + + +As you can see, these modules cover most demands of deep learning. + +You can find what you want and click the link to quickly learn HOW TO USE it! All modules are designed to be easy to use; that is +my principle.
+ + diff --git a/examples/1.trainer.py b/examples/1.trainer.py index 07b290f6..0e001b14 100644 --- a/examples/1.trainer.py +++ b/examples/1.trainer.py @@ -72,8 +72,8 @@ def test_step(self, batch, params: ParamsType = None) -> MetricType: builder = ( DatasetBuilder().add_input('xs', range(-500, 500)).add_input('ys', range(-500, 500)) .add_output('xs', 'xs').add_output('ys', 'ys') - .add_output_transform('xs', lambda x: torch.tensor([x])) - .add_output_transform('ys', lambda x: torch.tensor([x + 1])) + .set_output_transform('xs', lambda x: torch.tensor([x])) + .set_output_transform('ys', lambda x: torch.tensor([x + 1])) .random_sampler().chain() ) diff --git a/examples/data/nest_datasets.py b/examples/data/nest_datasets.py index ec5dba83..16216cef 100644 --- a/examples/data/nest_datasets.py +++ b/examples/data/nest_datasets.py @@ -20,7 +20,7 @@ class SameClass: def __init__(self, db: DatasetBuilder): self.db = db - ys = db.get_source('ys') + ys = db.inputs['ys'] cls_num = len(set(ys.tolist())) pos_cls = [] for i in range(cls_num): diff --git a/extract_version.py b/extract_version.py index b6affc58..03e03a4f 100644 --- a/extract_version.py +++ b/extract_version.py @@ -2,4 +2,3 @@ if __name__ == '__main__': print(extract_version()) - diff --git a/publish.cmd b/publish.cmd deleted file mode 100755 index cb25b361..00000000 --- a/publish.cmd +++ /dev/null @@ -1 +0,0 @@ -python -m twine upload --skip-existing dist/* \ No newline at end of file diff --git a/publish.sh b/publish.sh deleted file mode 100755 index cb25b361..00000000 --- a/publish.sh +++ /dev/null @@ -1 +0,0 @@ -python -m twine upload --skip-existing dist/* \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 8a58b5d7..0322e2f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.pytest.ini_options] minversion = "6.0" -addopts = '--cov=lumo' +#addopts = '--cov=lumo' testpaths = [ "tests", ] @@ -46,9 +46,34 @@ omit = [ 'src/lumo/core/record_backend/*', 
'src/lumo/utils/memory_grab.py', 'src/lumo/data/collate.py', - 'src/lumo/utils/*', + 'src/lumo/utils/screen.py', + 'src/lumo/utils/timer.py', + 'src/lumo/utils/exithook.py', + 'src/lumo/utils/filelock.py', + 'src/lumo/utils/filelock2.py', + 'src/lumo/utils/fmt.py', + 'src/lumo/utils/hash.py', + 'src/lumo/utils/logger.py', + 'src/lumo/utils/cache.py', + 'src/lumo/utils/ast.py', + 'src/lumo/utils/memory_grab.py', ] +exclude_lines = [ + # Have to re-enable the standard pragma + "pragma: no cover", + # Don't complain about missing debug-only code:s + "def __repr__", + "if self.debug", + # Don't complain if tests don't hit defensive assertion code: + "raise AssertionError", + "raise NotImplementedError", + "AbstractMethodError", + # Don't complain if non-runnable code isn't run: + "if 0:", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] [tool.coverage.html] directory = 'coverage_html_report' \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 955f32e8..2329b936 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,8 +10,8 @@ tqdm rich gitPython PyYAML>3.13 -tensorboardX dbrecord packaging pandas -hydra-core \ No newline at end of file +hydra-core +tensorboard \ No newline at end of file diff --git a/setup.py b/setup.py index 0a8ebd89..f8f61652 100644 --- a/setup.py +++ b/setup.py @@ -30,8 +30,5 @@ def extract_version(): keywords='lumo', packages=find_packages('src'), entry_points={ - 'console_scripts': [ - 'lumo = lumo.cli.cli:main' - ] }, ) diff --git a/src/lumo/__init__.py b/src/lumo/__init__.py index 09c4bdd6..957eb07f 100644 --- a/src/lumo/__init__.py +++ b/src/lumo/__init__.py @@ -1,9 +1,10 @@ """ """ -__version__ = "0.14.5" +__version__ = "0.14.6" from .core import Params, ParamsType, MetricType, Meter, Record, TrainStage, BaseParams + from .data import DataLoader, DataModule, DatasetBuilder, LumoDataLoader, CollateBase, DataLoaderSide from .exp import SimpleExperiment, Experiment from .trainer import Trainer, 
TrainerParams, callbacks, RndManager diff --git a/src/lumo/analyse/condition.py b/src/lumo/analyse/condition.py index 7036b7c8..10d21d45 100644 --- a/src/lumo/analyse/condition.py +++ b/src/lumo/analyse/condition.py @@ -61,13 +61,15 @@ def __neg__(self): return self def __ge__(self, other): - assert other is not None + if other is None: + raise AssertionError() self.value = other self.op = ">=" return self def __le__(self, other): - assert other is not None + if other is None: + raise AssertionError() self.value = other self.op = "<=" return self @@ -83,7 +85,8 @@ def __ne__(self, other): return self def __gt__(self, other): - assert other is not None + if other is None: + raise AssertionError() self.value = other self.op = ">" return self diff --git a/src/lumo/cli/__init__.py b/src/lumo/cli/__init__.py deleted file mode 100644 index 988dce43..00000000 --- a/src/lumo/cli/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" - - 提供命令行参数控制 -""" - diff --git a/src/lumo/cli/cli.py b/src/lumo/cli/cli.py deleted file mode 100644 index 530163de..00000000 --- a/src/lumo/cli/cli.py +++ /dev/null @@ -1,89 +0,0 @@ -import fire -import os -from .functional import * -from ..exp import Experiment - -doc = """ -Usage: -# create templete directory -lumo init [dir] - -# easier way to open tensorboard -# lumo board [--logdir=] -# lumo board [--test=] # find test_name and tensorboard it -# lumo board # default open ./board - -# lumo mark - -# restore code snapshot of some test -lumo reset - -# archive code snapshot of some test -lumo archive - -# print log file -lumo log - -# print params of this test -lumo params - -# /--test=/--test_name= - -# TODO -lumo config local --k=v -lumo config global --k=v - -# get a free port -lumo port - -""" - - -class Main: - def sum(self, tid): - """ - - Args: - tid: test_name or test_root - - Returns: - - """ - from ..exp.finder import summary_experiment - summary_experiment(tid) - - def today(self): - pass - - def init(self, path): - git_init(path) - 
print(os.path.abspath(path)) - - def extract(self, test_root, output=None, verbose=True): - exp = Experiment.from_disk(test_root) - test_extract(test_root, output=output, verbose=verbose) - - def clone(self, arg: str, alias: str = None): - """ - if template: - git clone template_map[arg] alias - else: - git clone alias - - Args: - arg: url or template name - template: template id - alias: alias - - Returns: - - """ - if '/' not in arg: - _, path = git_clone_from_template(arg, alias) - else: - _, path = git_clone(arg, alias) - git_init(path) - - -fire.Fire(Main()) -exit(0) diff --git a/src/lumo/cli/functional/__init__.py b/src/lumo/cli/functional/__init__.py deleted file mode 100644 index c97b6540..00000000 --- a/src/lumo/cli/functional/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .init import git_init - -from .clone import git_clone, git_clone_from_template -from .extract import test_extract diff --git a/src/lumo/cli/functional/clone.py b/src/lumo/cli/functional/clone.py deleted file mode 100644 index 920f2479..00000000 --- a/src/lumo/cli/functional/clone.py +++ /dev/null @@ -1,30 +0,0 @@ -import os.path - -import git -from git import RemoteProgress -from urllib3.util import parse_url -from lumo import Logger - -log = Logger() -template_map = { - 'classify': 'https://github.com/pytorch-lumo/wsl-baselines' -} - - -def prograss(*args, **kwargs): - log.inline(*args) - - -def git_clone(url, alias=None): - if alias is None: - u = parse_url(url) - alias = u.path.split('/')[-1] - - res = git.Repo.clone_from(url, alias, progress=prograss) - log.newline() - return res, alias - - -def git_clone_from_template(template, alias=None): - url = template_map[template] - return git_clone(url, alias) diff --git a/src/lumo/cli/functional/extract.py b/src/lumo/cli/functional/extract.py deleted file mode 100644 index 464db72d..00000000 --- a/src/lumo/cli/functional/extract.py +++ /dev/null @@ -1,34 +0,0 @@ -from lumo.exp import Experiment -import os - -import zipfile - - -def 
test_extract(test_root, output=None, verbose=True): - exp = Experiment.from_disk(test_root) - if output is None: - output = os.path.join(os.getcwd(), f'{exp.test_name}.zip') - - z = zipfile.ZipFile(output, 'a', zipfile.ZIP_DEFLATED) - - if verbose: - print('deflate info') - for root, dirs, fs in os.walk(exp.test_root): - for f in fs: - a = os.path.join(root, f) - b = os.path.join('/', exp.exp_name, exp.test_name, root.replace(exp.test_root, 'experiment').lstrip('/'), f) - z.write(a, b - ) - if verbose: - print(f'{a} => {b}') - - if verbose: - print('deflate blob') - for root, dirs, fs in os.walk(exp.blob_root): - for f in fs: - a = os.path.join(root, f) - b = os.path.join('/', exp.exp_name, exp.test_name, root.replace(exp.blob_root, 'blob/'), f) - z.write(a, b - ) - if verbose: - print(f'{a} => {b}') diff --git a/src/lumo/cli/functional/init.py b/src/lumo/cli/functional/init.py deleted file mode 100644 index 2a9a7610..00000000 --- a/src/lumo/cli/functional/init.py +++ /dev/null @@ -1,68 +0,0 @@ -import git -from pathlib import Path - -from lumo.utils.repository import git_commit - -git_ignore = ['# Byte-compiled / optimized / DLL files', '.DS_Store', '__pycache__/', '*.py[cod]', '*$py.class', '', - '# C extensions', '*.so', '', '# Distribution / packaging', '.Python', 'build/', 'develop-eggs/', 'dist/', - 'downloads/', 'eggs/', '.eggs/', 'lib/', 'lib64/', 'parts/', 'sdist/', 'var/', 'wheels/', - 'share/python-wheels/', - '*.egg-info/', '.installed.cfg', '*.egg', 'MANIFEST', '', '# PyInstaller', - '# Usually these files are written by a python script from a template', - '# before PyInstaller builds the exe, so as to inject date/other infos into it.', '*.manifest', '*.spec', - '', - '# Installer logs', 'pip-log.txt', 'pip-delete-this-directory.txt', '', '# Unit test / coverage reports', - 'htmlcov/', - '.tox/', '.nox/', '.coverage', '.coverage.*', '.cache', 'nosetests.xml', 'coverage.xml', '*.cover', - '*.py,cover', - '.hypothesis/', '.pytest_cache/', 'cover/', '', 
'# Translations', '*.mo', '*.pot', '', '# Django stuff:', - '*.log', - 'local_settings.py', 'db.sqlite3', 'db.sqlite3-journal', '', '# Flask stuff:', 'instance/', - '.webassets-cache', '', - '# Scrapy stuff:', '.scrapy', '', '# Sphinx documentation', 'docs/_build/', '', '# PyBuilder', - '.pybuilder/', - 'target/', '', '# Jupyter Notebook', '.ipynb_checkpoints', '', '# IPython', 'profile_default/', - 'ipython_config.py', - '', '# pyenv', '# For a library or package, you might want to ignore these files since the code is', - '# intended to run in multiple environments; otherwise, check them in:', '# .python-version', '', - '# pipenv', - '# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.', - '# However, in case of collaboration, if having platform-specific dependencies or dependencies', - "# having no cross-platform support, pipenv may install dependencies that don't work, or not", - '# install all needed dependencies.', '#Pipfile.lock', '', - '# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow', - '__pypackages__/', '', '# Celery stuff', 'celerybeat-schedule', 'celerybeat.pid', '', - '# SageMath parsed files', - '*.sage.py', '', '# Environments', '.env', '.venv', 'env/', 'venv/', 'ENV/', 'env.bak/', 'venv.bak/', '', - '# Spyder project settings', '.spyderproject', '.spyproject', '', '# Rope project settings', - '.ropeproject', '', - '# mkdocs documentation', '/site', '# pycharm', '.idea', '', '', '# mypy', '.mypy_cache/', '.dmypy.json', - 'dmypy.json', - '', '# Pyre type checker', '.pyre/', '', '# pytype static type analyzer', '.pytype/', '', - '# Cython debug symbols', - 'cython_debug/', '.thexp/', 'repo.json', '.expsdirs', '.idea/', '*.pth', '*.npy', - '*.ckpt', - '*.thexp.*', '*.pkl', '.cache/', '.lumo/config.json', '*.lumo.*', '*.ft', 'kk', 'temp', '.idea', '', - 'lumo_temp', - '.lumo/', '*scratch*'] - - -def check_gitignore(path): - ignore_file = Path(path).joinpath('.gitignore') - if ignore_file.exists(): - res = ignore_file.read_text().split('\n') - new_ignore = [i for i in git_ignore if i not in res] - else: - new_ignore = git_ignore - - Path(path).joinpath('.gitignore').write_text('\n'.join(new_ignore)) - - -def git_init(path=None): - if path is None: - path = '/' - - repo = git.Repo.init(path, mkdir=True) - check_gitignore(path) - git_commit(repo, branch_name=None) - return repo diff --git a/src/lumo/contrib/accelerate/data_loader.py b/src/lumo/contrib/accelerate/data_loader.py index dc82f282..f05ac57f 100644 --- a/src/lumo/contrib/accelerate/data_loader.py +++ b/src/lumo/contrib/accelerate/data_loader.py @@ -4,22 +4,7 @@ from accelerate.utils import send_to_device from lumo import LumoDataLoader -from lumo.data.loader import DataLoaderIterWrap -class DataLoaderDispatcher(_DataLoaderDispatcher, LumoDataLoader): - - def __iter__(self) -> DataLoaderIterWrap: - return super().__iter__() - - -class DataLoaderShard(_DataLoaderShard, LumoDataLoader): - - def __iter__(self): - if self.rng_types is not None: - 
synchronize_rng_states(self.rng_types, self.generator) - state = AcceleratorState() - for batch in LumoDataLoader.__iter__(self): - if state.distributed_type == DistributedType.TPU: - xm.mark_step() - yield batch if self.device is None else send_to_device(batch, self.device) +class DataLoaderDispatcher(_DataLoaderDispatcher): + pass diff --git a/src/lumo/contrib/data/splits.py b/src/lumo/contrib/data/splits.py index cdd1a52b..265ba8f4 100644 --- a/src/lumo/contrib/data/splits.py +++ b/src/lumo/contrib/data/splits.py @@ -48,7 +48,8 @@ def train_val_split(target, val_size=10000, train_size=None): np.random.shuffle(idx) if train_size is not None: - assert size > val_size + train_size, "should less than {}, but {}".format(size, train_size + val_size) + if size <= val_size + train_size: + raise AssertionError("should less than {}, but {}".format(size, train_size + val_size)) return idx[val_size:val_size + train_size], idx[:val_size] return idx[val_size:], idx[:val_size] diff --git a/src/lumo/cli/functional/lists.py b/src/lumo/contrib/matplotlib/__init__.py similarity index 100% rename from src/lumo/cli/functional/lists.py rename to src/lumo/contrib/matplotlib/__init__.py diff --git a/src/lumo/contrib/matplotlib/curve.py b/src/lumo/contrib/matplotlib/curve.py new file mode 100644 index 00000000..0efd5938 --- /dev/null +++ b/src/lumo/contrib/matplotlib/curve.py @@ -0,0 +1,7 @@ +from matplotlib import pyplot as plt + + +def curve(xs, ys): + plt.plot(xs, ys, 'o') + plt.plot(xs, ys, '#1f77b4') + plt.grid() diff --git a/src/lumo/contrib/module/memoty_bank.py b/src/lumo/contrib/module/memoty_bank.py index 14955e12..f3da0c58 100644 --- a/src/lumo/contrib/module/memoty_bank.py +++ b/src/lumo/contrib/module/memoty_bank.py @@ -26,8 +26,8 @@ def __getitem__(self, item): @torch.no_grad() def push(self, name, value): - assert name in self.offsets - assert self[name].ndim == value.ndim + if name not in self.offsets or self[name].ndim != value.ndim: + raise AssertionError() value =
value.detach() value = gather(value) diff --git a/src/lumo/contrib/torch/tensor.py b/src/lumo/contrib/torch/tensor.py index 5cb0e7cb..b183f73e 100644 --- a/src/lumo/contrib/torch/tensor.py +++ b/src/lumo/contrib/torch/tensor.py @@ -7,7 +7,9 @@ def rotate_right_angle(x: torch.Tensor, w_dim: int = 2, h_dim: int = 3, degree: int = 90): - assert degree in {90, 270, 180} + if degree not in {90, 270, 180}: + raise AssertionError() + if degree == 90: x = x.transpose(w_dim, h_dim) # 90 elif degree == 180: diff --git a/src/lumo/core/__init__.py b/src/lumo/core/__init__.py index 1f41063e..498b264a 100644 --- a/src/lumo/core/__init__.py +++ b/src/lumo/core/__init__.py @@ -1,5 +1,4 @@ from .params import * -from .metaclasses import * from .attr import Attr from . import interp from .meter import Meter diff --git a/src/lumo/core/attr.py b/src/lumo/core/attr.py index af775a1c..3c0bf610 100644 --- a/src/lumo/core/attr.py +++ b/src/lumo/core/attr.py @@ -8,45 +8,79 @@ class Attr(OrderedDict): def __setattr__(self, key: str, value): - _set_item(self, key.split('.'), value) + set_item_iterative(self, key.split('.'), value) def __setitem__(self, key, value): if not isinstance(key, str): raise TypeError('Key in attr must be str') - _set_item(self, key.split('.'), value) + set_item_iterative(self, key.split('.'), value) def __getattr__(self, key: str): try: - res = _get_item(self, key.split('.')) + res = get_item_iterative(self, key.split('.')) except KeyError: res = Attr() - _set_item(self, key.split('.'), res) + set_item_iterative(self, key.split('.'), res) return res def __getitem__(self, key): if not isinstance(key, str): raise TypeError('Key in attr must be str') - return _get_item(self, key.split('.')) + return get_item_iterative(self, key.split('.')) -def _set_item(dic, keys: List[str], value): +def safe_update_dict(src: dict, kwargs: dict, assert_type=True): + for ks, v in walk_dict(kwargs): + try: + old_v = get_item_iterative(src, ks) + if old_v is None or isinstance(old_v, 
type(v)): + set_item_iterative(src, ks, v) + # print(ks, v) + else: + raise TypeError(ks, type(old_v), type(v)) + except KeyError: + set_item_iterative(src, ks, v) + # print(ks, v) + return src + + +def walk_dict(dic: dict, root=None): + if root is None: + root = [] + for k, v in dic.items(): + if isinstance(v, dict): + yield from walk_dict(v, [*root, *k.split('.')]) + else: + yield [*root, *k.split('.')], v + + +def set_item_iterative(dic: dict, keys: List[str], value): if len(keys) == 1: if isinstance(value, dict): - value = dic.update(value) - OrderedDict.__setitem__(dic, keys[0], value) + for ks, v in walk_dict(value): + set_item_iterative(dic, [*keys, *ks], v) + else: + dict.__setitem__(dic, keys[0], value) else: - nex = Attr() - OrderedDict.__setitem__(dic, keys[0], nex) - _set_item(nex, keys[1:], value) + try: + nex = dict.__getitem__(dic, keys[0]) + if not isinstance(nex, dict): + raise ValueError(keys[0], nex) + # dict.__setitem__(dic, keys[0], nex) + except KeyError: + nex = dict() + dict.__setitem__(dic, keys[0], nex) + + set_item_iterative(nex, keys[1:], value) -def _get_item(dic, keys: List[str]): +def get_item_iterative(dic: dict, keys: List[str]): if len(keys) == 1: - return OrderedDict.__getitem__(dic, keys[0]) + return dict.__getitem__(dic, keys[0]) else: - nex = OrderedDict.__getitem__(dic, keys[0]) + nex = dict.__getitem__(dic, keys[0]) if isinstance(nex, dict): - return _get_item(nex, keys[1:]) + return get_item_iterative(nex, keys[1:]) else: raise KeyError(keys) diff --git a/src/lumo/core/disk.py b/src/lumo/core/disk.py index 47f60609..5f675d7c 100644 --- a/src/lumo/core/disk.py +++ b/src/lumo/core/disk.py @@ -76,7 +76,7 @@ def update_metric(self, key, value, compare=None, flush=False): if older > value: update = True else: - assert False + raise NotImplementedError() if update: dic[key] = value @@ -107,7 +107,7 @@ def update_metric_pair(self, key, value, key2, value2, compare=None, flush=False if old > value: update = True else: - assert False 
+ raise NotImplementedError() if update: dic[key] = value diff --git a/src/lumo/core/interp.py b/src/lumo/core/interp.py index 940c7e7e..100f4454 100644 --- a/src/lumo/core/interp.py +++ b/src/lumo/core/interp.py @@ -25,7 +25,6 @@ from omegaconf import DictKeyType from lumo.core import BaseParams -from lumo.core.metaclasses import PropVar __all__ = ['Cos', 'Linear', diff --git a/src/lumo/core/metaclasses.py b/src/lumo/core/metaclasses.py deleted file mode 100644 index 7f571bd5..00000000 --- a/src/lumo/core/metaclasses.py +++ /dev/null @@ -1,108 +0,0 @@ -from abc import ABC, ABCMeta -from collections import OrderedDict - -__all__ = ['PropVar', 'OrderedPropVar', 'AttrPropVar', 'ABCPropVar'] - -from typing import List -from functools import wraps - - -def make_dicts(type_name, names: List[str], dic_type=dict): - def outer(func): - @wraps(func) - def inner(cls, *args, **kwargs): - self = func(cls) - for name in names: - object.__setattr__(self, name, dic_type()) - # setattr(self, name, dic_type()) - return self - - return inner - - type_name.__new__ = outer(type_name.__new__) - - -def make_dict(type_name, name: str, default): - def outer(func): - @wraps(func) - def inner(cls, *args, **kwargs): - self = func(cls) - setattr(self, name, default) - return self - - return inner - - type_name.__new__ = outer(type_name.__new__) - - -class PropVar(type): - """ - """ - - def __new__(cls, name, bases, attrs: dict, **kwds): - clazz = type.__new__(cls, name, bases, dict(attrs)) - make_dicts(clazz, ['_prop']) - return clazz - - -class OrderedPropVar(type): - """ - """ - - def __new__(cls, name, bases, attrs: dict, **kwds): - clazz = type.__new__(cls, name, bases, dict(attrs)) - make_dicts(clazz, ['_prop'], OrderedDict) - return clazz - - -class AttrPropVar(type): - """ - """ - - def __new__(cls, name, bases, attrs: dict, **kwds): - from .attr import Attr - clazz = type.__new__(cls, name, bases, dict(attrs)) - make_dicts(clazz, ['_prop'], Attr) - return clazz - - -class 
ABCPropVar(ABCMeta): - - def __new__(cls, name, bases, attrs: dict, **kwds): - from .attr import Attr - clazz = type.__new__(cls, name, bases, dict(attrs)) - make_dicts(clazz, ['_prop', '_content'], Attr) - return clazz - - -class Merge(type): - """ - 元类,用于将子类和父类共有字典,集合时,子类的覆盖行为改为合并父类的字典,集合 - - 由于用途特殊,仅识别类变量中以下划线开头的变量 - :: - class A(metaclass=Merge): - _dicts = {"1": 2, "3": 4} - - class B(A): - _dicts = {"5":6,7:8} - - print(B._dicts) - - result: - >>> {'5': 6, '3': 4, '1': 2, 7: 8} - """ - - def __new__(cls, name, bases, attrs: dict, **kwds): - for base in bases: - for key, value in base.__dict__.items(): # type:(str,Any) - if key.endswith("__"): - continue - if isinstance(value, set): - v = attrs.setdefault(key, set()) - v.update(value) - elif isinstance(value, dict): - v = attrs.setdefault(key, dict()) - v.update(value) - - return type.__new__(cls, name, bases, dict(attrs)) diff --git a/src/lumo/core/meter.py b/src/lumo/core/meter.py index 6ae1b27b..a8ef163f 100644 --- a/src/lumo/core/meter.py +++ b/src/lumo/core/meter.py @@ -10,11 +10,11 @@ import torch from lumo.utils.fmt import to_ndarray, detach, is_scalar -from lumo.core import PropVar -class Meter(metaclass=PropVar): +class Meter: def __init__(self): + self._prop = {} self._rec = {} self._avg = {} @@ -161,9 +161,9 @@ def __init__(self, item=None, gb_method=None): self.c = len(self.acc) self.cur = item if gb_method == 'max': - self.last = float('-inf') + self.last = -1e12 elif gb_method == 'min': - self.last = float('inf') + self.last = 1e12 else: self.last = 0 diff --git a/src/lumo/core/params.py b/src/lumo/core/params.py index 5889b31a..745877cd 100644 --- a/src/lumo/core/params.py +++ b/src/lumo/core/params.py @@ -1,4 +1,5 @@ import json +import os.path import sys import textwrap from pprint import pformat @@ -10,6 +11,7 @@ from omegaconf import DictConfig, OmegaConf, DictKeyType from omegaconf._utils import _ensure_container +from .attr import safe_update_dict, set_item_iterative from .raises import 
BoundCheckError, NewParamWarning # arange_param = namedtuple('arange_param', ['default', 'left', 'right'], defaults=[None, float('-inf'), float('inf')]) @@ -234,20 +236,24 @@ def choice(self, *choices) -> Choices: """ return Choices(choices[0], choices) + def safe_update(self, dic, assert_type=True): + self.update( + safe_update_dict(self.to_dict(), dic, assert_type=assert_type) + ) + def from_dict(self, dic: dict): - for k, v in dic.items(): - self[k] = v + self.safe_update(dic) return self def from_kwargs(self, **kwargs): return self.from_dict(kwargs) def from_json(self, file): - self.update(json.loads(Path(file).read_text())) + self.safe_update(json.loads(Path(file).read_text()), assert_type=True) return self def from_yaml(self, file): - self.update(OmegaConf.load(file)) + self.safe_update(dict(OmegaConf.load(file)), assert_type=True) return self def from_args(self, argv: list = None): @@ -255,23 +261,23 @@ def from_args(self, argv: list = None): argv = sys.argv def func(**kwargs): - if '_help' in kwargs: + if 'help' in kwargs: print(self) exit() return - if '_json' in kwargs: - self.from_json(kwargs['_json']) - return + config = kwargs.get('config') + if config is None: + config = kwargs.get('c') + if config is not None and isinstance(config, str) and os.path.exists(config): + self.from_yaml(config) + dic = {} for k, v in kwargs.items(): - # try: - # _get_item(self, k.split('.')) - # except: - # self[k] = v - _set_item(self, k.split('.'), v) + set_item_iterative(dic, k.split('.'), v) + self.safe_update(dic) - fire.Fire(func) + fire.Fire(func, command=argv) return self def from_hydra(self, config_path, config_name): diff --git a/src/lumo/core/record.py b/src/lumo/core/record.py index 914a76dc..7d1cda2a 100644 --- a/src/lumo/core/record.py +++ b/src/lumo/core/record.py @@ -2,7 +2,6 @@ from numbers import Number from . 
import Attr -from .metaclasses import PropVar from .meter import Meter import torch import numpy as np diff --git a/src/lumo/data/README.md b/src/lumo/data/README.md deleted file mode 100644 index 866f0af0..00000000 --- a/src/lumo/data/README.md +++ /dev/null @@ -1,5 +0,0 @@ - - - 数据集建立、划分 - - 基于 github template checkout 创建数据集模板 - - DataLoader 划分 - - transform 函数 \ No newline at end of file diff --git a/src/lumo/data/builder.py b/src/lumo/data/builder.py index 64a500b4..32f5722a 100644 --- a/src/lumo/data/builder.py +++ b/src/lumo/data/builder.py @@ -15,6 +15,95 @@ class DatasetBuilder(Dataset): + """ + A subclass of the Dataset class designed for quick and common dataset construction. + + For instance, a simple CIFAR-10 Dataset in (x, y) format can be easily created using the following code: + ```python + from torchvision.datasets.cifar import CIFAR10 + ds = CIFAR10(...) + x,y = ds[0] + ``` + However, there may be requirements to modify the format, such as (x1, x2, y), due to tricks like cutmix or multiple argument transformers. In this case, we need to extend the CIFAR10 class as follows: + ```python + class MyCIFAR10(CIFAR10): + def __getitem__(self,index): + ... + return x1, x2, y + ``` + + If we have multiple format requirements, we can either add extra arguments to the MyCIFAR10 class or create more subclasses, but both methods are redundant. + + Now, by using the DatasetBuilder, this process can be simplified to the utmost. + + To begin, we need to prepare the data source/inputs for CIFAR-10, which includes images and targets. We can register these inputs using `.add_input(name, source)` as follows: + + ```python + source = CIFAR10() + ds = ( + DatasetBuilder() + .add_input('xs', source.data) + .add_input('ys', source.targets) + ) + ``` + + Next, we define the outputs. 
If we want the output format to be (xs, ys), we can use the following code: + ```python + (ds + .add_output('xs','xs') + .add_output('ys','ys') + ) + ``` + The function `.add_output(source_name, output_name)` defines a data flow from input to output. + In `.add_output('xs', 'xs')`, the input is `source.data` named as 'xs', and the output is also named as 'xs'. + If we want to rename the output name to 'xs1', we can use .add_output('xs', 'xs1'). + + Now you can see the benefits of this approach. If you need an output format like `(xs1, xs2, ys)`, you just need to modify the code as follows: + ```python + (ds + .add_output('xs','xs1') + .add_output('xs','xs2') + .add_output('ys','ys') + ) + ``` + + Besides, you can access the index of each data by `.add_idx('name')`. + ``` + ds.add_idx('idx') + print(ds[0]) + {'idx': 0, ...others...} + ``` + + Finally, we can use transforms. Each input and output can be passed a transform parameter during definition, such as: + ```python + ds.add_input('xs', xs, transform=default_loader) + ds.add_output('xs','xs1',transform=randargument) + ds.add_output('xs','xs2',transform=weak) + ``` + + The transform defined at the input stage will only be called once when there is a corresponding output. Each output has its own transform. + That is to say, the transform execution process defined by the above code will look like: + ```python + x -> default_loader -> randargument -> xs1 + \-> weak -> xs2 + ``` + + It's possible that you may be confused about the usage of output names. Commonly, the output types of `dataset[index]` are `list` or `dict`. + DatasetBuilder provides both types for everyone, and by default, the dict type is used. 
When defining outputs as `(xs1, xs2, ys)`, the output of `ds[index]` should be: + ```python + {'xs1': np.array, 'xs2': np.array, 'ys': 1} + ``` + + You can change the output type to list by calling `.chain()`: + ```python + ds.chain() + xs1,xs2,ys = ds[index] + ``` + + + + """ + def __init__(self): self._prop = {} @@ -269,6 +358,18 @@ def add_idx(self, name): return self def add_input(self, name: str, source, transform: SingleValueTransform = None): + """ + Register an input source with the transform (if provided). + Args: + name: source name + source: source, should be a sized object. + transform: + + + Notes: + Iterable objects without a `__len__` method are currently not well-tested. Be careful when using them in DatasetBuilder. + + """ assert name not in self._data, f'Source name {name} duplicated.' self._check_source(name, source) self._data[name] = source @@ -277,10 +378,20 @@ def add_input(self, name: str, source, transform: SingleValueTransform = None): def add_input_transform(self, name: str, transform: SingleValueTransform = None): assert name in self._data, f'Source {name} should be added.' - self._transforms[name] = transform - return self + warnings.warn('`add` may cause confusion, use set_input_transform ') + return self.set_input_transform(name, transform) def add_output(self, name: str, outkey: str, transform: SingleValueTransform = None): + """ + Add a data flow from inputs[name] to outputs[outkey] with the transform (if provided). + Args: + name: source name of inputs + outkey: output name of output + transform: a callable function + + Returns: + + """ assert name in self._data, f'Must have data source {name} first.' outkeys = self._outs.setdefault(name, list()) @@ -292,19 +403,39 @@ def add_output(self, name: str, outkey: str, transform: SingleValueTransform = N return self def add_output_transform(self, outkey: str, transform: SingleValueTransform = None): + """ + Add or **replace** transform of the output name. + Args: + outkey: output name. 
+ transform: a callable function + """ assert outkey in self._outkeys, f'Output key {outkey} should be added.' - self._transforms[f'::{outkey}'] = transform - return self + warnings.warn('add may cause confusion, use set_output_transform ') + return self.set_output_transform(outkey, transform) def add_global_transform(self, transform: DictTransform): self._transforms['::global::'] = transform return self - def set_input_transform(self, name, transform: SingleValueTransform): + def set_input_transform(self, name, transform: SingleValueTransform = None): + """ + Add or **replace** transform of the input source {name}. + Args: + name: source name. + transform: a callable function + + """ self._transforms[name] = transform return self - def set_output_transform(self, outkey, transform: SingleValueTransform): + def set_output_transform(self, outkey, transform: SingleValueTransform = None): + """ + Add or **replace** transform of the output {name}. + Args: + outkey: output name. + transform: a callable function + + """ self._transforms[f'::{outkey}'] = transform return self diff --git a/src/lumo/data/datamodule.py b/src/lumo/data/datamodule.py index e57f9437..e858b1ac 100644 --- a/src/lumo/data/datamodule.py +++ b/src/lumo/data/datamodule.py @@ -2,7 +2,7 @@ from torch.utils.data import DataLoader -from lumo.core import PropVar, ParamsType, TrainStage +from lumo.core import TrainStage, ParamsType from .loader import DataLoaderType from .loader import DataLoaderSide diff --git a/src/lumo/data/loader.py b/src/lumo/data/loader.py index f6e12215..e178b0fa 100644 --- a/src/lumo/data/loader.py +++ b/src/lumo/data/loader.py @@ -5,50 +5,8 @@ from torch.utils.data import DataLoader -class DataLoaderIterWrap: - def __init__(self, iter_fn, batch_count=None): - self.iter_fn = iter_fn - self.iter = iter_fn() - self.c = 0 - self.batch_count = batch_count - - def __iter__(self): - while True: - try: - yield next(self) - except StopIteration: - break - - def __len__(self): - if 
self.batch_count is None: - return len(self.iter) - else: - return self.batch_count - - def __next__(self): - if self.batch_count is not None: - if self.c >= self.batch_count: - raise StopIteration() - try: - batch = next(self.iter) - except StopIteration as e: - if self.batch_count is not None: - self.iter = self.iter_fn() - batch = next(self.iter) - else: - raise e - - self.c += 1 - return batch - - class LumoDataLoader(DataLoader): - - def _wraped_iter_(self): - return DataLoaderIterWrap(super().__iter__, len(self)) - - def __iter__(self) -> DataLoaderIterWrap: - return self._wraped_iter_() + pass def summarize_loader(loader: DataLoader): diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index a052732f..6738f690 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -9,12 +9,13 @@ from lumo.decorators.process import call_on_main_process_wrap from lumo.proc import glob from lumo.proc.dist import is_dist, is_main, local_rank -from lumo.proc.path import blobroot, libhome +from lumo.proc.path import blobroot, libhome, progressroot from lumo.proc.path import exproot, local_dir from lumo.utils import safe_io as io from lumo.utils.fmt import can_be_filename from lumo.utils.logger import Logger from .base import ExpHook +from ..proc.pid import pid_hash, runtime_pid_obj def checkdir(path: Union[Path, str]): @@ -29,9 +30,19 @@ class Experiment: """ (by default), the directory structure is as following: .lumo (libroot) + - progress + - ".{pid}" -> hash + if pid exists and hash(psutil.Process) == hash in file: is run + else: is closed - experiments # (exp_root) record information (e.g., .log, params files, etc.) - {experiment-name-1} - {test-1} + # infomation + { + progress + pid_hash (for lumo.client monitor) + other_info: git, file, version_lock, etc. 
+ } - {test-2} - {experiment-name-2} - {test-1} @@ -147,11 +158,22 @@ def blob_branch(self): val = Path(blobroot()).joinpath(self.exp_name, self.test_name) return checkdir(val) + @property + def progress_branch(self): + val = Path(progressroot()) + return checkdir(val) + @property def test_branch(self): val = self.exp_branch.joinpath(self.test_name) return checkdir(val) + def dump_progress(self, ratio: float, update_from=None): + res = {'ratio': ratio} + if update_from is None: + res['update_from'] = update_from + self.dump_info('progress', res, append=True) + def dump_info(self, key: str, info: dict, append=False, info_dir='info', set_prop=True): fn = self.test_file(f'{key}.json', info_dir) if append: @@ -250,6 +272,9 @@ def blob_file(self, filename, *args): parent = self.blob_branch.joinpath(*args) return checkdir(parent).joinpath(filename).as_posix() + def progress_file(self, filename): + return self.progress_branch.joinpath(filename).as_posix() + def blob_dir(self, *args): """ @@ -289,6 +314,7 @@ def exp_func(): @call_on_main_process_wrap def initial(self): self.add_tag(self.__class__.__name__, 'exp_type') + self.dump_progress(0) self.dump_info('execute', { 'repo': self.project_root, 'cwd': os.getcwd(), @@ -296,6 +322,14 @@ def initial(self): 'exec_bin': sys.executable, 'exec_argv': sys.argv }) + self.dump_info('pinfo', { + 'pid': os.getpid(), + 'hash': pid_hash(), + 'obj': runtime_pid_obj(), + }) + + # register progress + io.dump_text(self.test_root, self.progress_file(f'{os.getpid()}')) @call_on_main_process_wrap def start(self): @@ -304,7 +338,6 @@ def start(self): self.initial() self.set_prop('start', True) for hook in self._hooks.values(): # type: ExpHook - hook.on_start(self) return self @@ -314,6 +347,7 @@ def end(self, end_code=0, *args, **extra): return if self.get_prop('end', False): return + self.dump_progress(1) self.set_prop('end', True) for hook in self._hooks.values(): # type: ExpHook hook.on_end(self, end_code=end_code, *args, **extra) diff 
--git a/src/lumo/exp/finder.py b/src/lumo/exp/finder.py index 6c50daf6..28c63b47 100644 --- a/src/lumo/exp/finder.py +++ b/src/lumo/exp/finder.py @@ -30,12 +30,12 @@ def _get_exp_name(exp_path: str): def list_all(exp_root=None) -> Dict[str, List[Experiment]]: return { - _get_exp_name(exp_path): retrieval_tests(exp_path) + _get_exp_name(exp_path): retrieval_tests_from_experiment(exp_path) for exp_path in list_experiment_paths(exp_root) } -def retrieval_tests(exp_path) -> List[Experiment]: +def retrieval_tests_from_experiment(exp_path) -> List[Experiment]: return [retrieval_experiment(os.path.join(exp_path, f)) for f in os.listdir(exp_path)] diff --git a/src/lumo/proc/config.py b/src/lumo/proc/config.py index c8bea782..e3b74399 100644 --- a/src/lumo/proc/config.py +++ b/src/lumo/proc/config.py @@ -1,11 +1,15 @@ import json import os -__all__ = ['glob', 'global_config_path', 'local_config_path'] +__all__ = ['debug_mode', 'glob', 'global_config_path', 'local_config_path'] + +import tempfile +from typing import overload GLOBAL_DEFAULT = { 'home': os.path.expanduser("~/.lumo/"), 'cache_dir': os.path.expanduser("~/.cache/lumo"), + 'dev_branch': 'lumo_experiments', } @@ -48,5 +52,21 @@ def get_runtime_config(): return cfg +def debug_mode(base_dir=None, disable_git=True): + glob['exp_root'] = tempfile.mkdtemp(dir=base_dir) + glob['progress_root'] = tempfile.mkdtemp(dir=base_dir) + glob['home'] = tempfile.mkdtemp(dir=base_dir) + glob['cache_dir'] = tempfile.mkdtemp(dir=base_dir) + glob['blob_root'] = tempfile.mkdtemp(dir=base_dir) + glob['metric_root'] = tempfile.mkdtemp(dir=base_dir) + glob['HOOK_LOCKFILE'] = False + glob['HOOK_LASTCMD_DIR'] = tempfile.mkdtemp(dir=base_dir) + glob['HOOK_RECORDABORT'] = False + glob['HOOK_TIMEMONITOR'] = False + + if disable_git: + glob['HOOK_GITCOMMIT'] = False + + # A dict object contains runtime configuration. 
glob = get_runtime_config() diff --git a/src/lumo/proc/path.py b/src/lumo/proc/path.py index cb02f68d..acb45f3c 100644 --- a/src/lumo/proc/path.py +++ b/src/lumo/proc/path.py @@ -8,7 +8,6 @@ def home(): return os.path.expanduser("~") -@cache def cache_dir(): """ Directory to store cache files, like datasets. Can be shared for everyone. @@ -31,7 +30,6 @@ def cache_dir(): return res -@cache def libhome(): """Library home to store configs. Default is `~/.lumo`""" LIBHOME = glob.get('home', None) @@ -40,7 +38,6 @@ def libhome(): return os.path.join(home(), '.lumo') -@cache def exproot(): """Experiment root to store multiple experiments, default is `~/.lumo/experiments`""" EXP_ROOT = glob.get('exp_root', None) @@ -53,7 +50,18 @@ def exproot(): return res -@cache +def progressroot(): + """Progress root to store progress files of running experiments, default is `~/.lumo/progress`""" + PROGRESS_ROOT = glob.get('progress_root', None) + if PROGRESS_ROOT: + res = PROGRESS_ROOT + else: + res = os.path.join(libhome(), 'progress') + + os.makedirs(res, exist_ok=True) + return res + + def blobroot(): """Experiment root to store big files, default is `~/.lumo/blob`""" BLOB_ROOT = glob.get('blob_root', None) @@ -66,7 +74,6 @@ def blobroot(): return res -@cache def metricroot(): """ Only used for storing table_row instance. @@ -84,7 +91,6 @@ def metricroot(): return res -@cache def local_dir(): """ Project root, default is the parent directory of .git. 
diff --git a/src/lumo/proc/pid.py b/src/lumo/proc/pid.py new file mode 100644 index 00000000..174be98a --- /dev/null +++ b/src/lumo/proc/pid.py @@ -0,0 +1,20 @@ +from psutil import Process +import sys +from joblib import hash +import os + + +def runtime_pid_obj(pid=None): + if pid is None: + pid = os.getpid() + p = Process(pid) + obj = { + "pid": p.pid, "pname": p.name(), 'pstart': p.create_time(), 'argv': p.cmdline() + } + return obj + + +def pid_hash(pid_obj=None): + if pid_obj is None: + pid_obj = runtime_pid_obj() + return hash(pid_obj) diff --git a/src/lumo/trainer/callbacks.py b/src/lumo/trainer/callbacks.py index 55aa7088..e6c7a0b8 100644 --- a/src/lumo/trainer/callbacks.py +++ b/src/lumo/trainer/callbacks.py @@ -13,7 +13,7 @@ from lumo.utils.memory_grab import DeviceMem from torch.utils.data import DataLoader -from lumo.core import ParamsType, Meter, MetricType, Record, TrainStage, wrap_result +from lumo.core import Meter, MetricType, Record, TrainStage, wrap_result, ParamsType from lumo.data import DataModule from lumo.data.loader import summarize_loader, DataLoaderType from lumo.utils import fmt diff --git a/src/lumo/trainer/components.py b/src/lumo/trainer/components.py index a337c99c..17ff9e98 100644 --- a/src/lumo/trainer/components.py +++ b/src/lumo/trainer/components.py @@ -1,7 +1,8 @@ +from typing import NewType + import torch from lumo.core import Params -from lumo.core.metaclasses import make_dicts, make_dict from lumo.exp import SimpleExperiment from .factory import OptimFactory, InterpFactory @@ -15,6 +16,7 @@ def log_dir(self): @property def params_fn(self): res = self.test_file('params.yaml') + self.dump_string('params.yaml', res) return res @property @@ -36,48 +38,19 @@ def state_dict_dir(self): res = self.blob_dir('state_dict') return res - def dump_train_info(self, epoch: int): - self.dump_info('trainer', { - 'epoch': epoch - }, append=True) + def dump_train_eidx(self, eidx, epoch: int): + """ + Args: + eidx: start from 0, end at `epoch-1` 
+ epoch: + """ + self.dump_progress((eidx + 1) / epoch, update_from='trainer') class ReimplementExperiment(TrainerExperiment): pass -# class TrainerPropVar(type): -# def __new__(cls, name, bases, attrs: dict, **kwds): -# for base in bases: -# for key, value in base.__dict__.items(): # type:(str,Any) -# if key.endswith("__"): -# continue -# if isinstance(value, set): -# v = attrs.setdefault(key, set()) -# v.update(value) -# elif isinstance(value, dict): -# v = attrs.setdefault(key, dict()) -# v.update(value) -# -# clazz = type.__new__(cls, name, bases, dict(attrs)) -# -# make_dicts(clazz, [ -# '_prop', -# '_cmp', -# '_rev_index', -# '_call_order', -# ]) -# -# make_dict(clazz, '_state_dicts', { -# 'optims': set(), -# 'models': set(), -# 'others': set(), -# 'tensor.th': set(), -# 'tensor.np': set(), -# }) -# return clazz - - class TrainerParams(Params): OPTIM = OptimFactory SCHE = INTERP = InterpFactory diff --git a/src/lumo/trainer/rnd.py b/src/lumo/trainer/rnd.py index 701211ed..b7509c28 100644 --- a/src/lumo/trainer/rnd.py +++ b/src/lumo/trainer/rnd.py @@ -27,9 +27,6 @@ def mark(self, seed: Union[int, str]): """ random.fix_seed(random.hashseed(seed)) - def int_time(self): - return int(str(time.time()).split(".")[-1]) - def shuffle(self, seed=None): """ 打乱,一般用于复现试验的时候随机一个种子 @@ -41,6 +38,6 @@ def shuffle(self, seed=None): """ if seed is None: - random.fix_seed(self.int_time()) + random.fix_seed(random.int_time()) else: random.fix_seed(seed) diff --git a/src/lumo/trainer/trainer.py b/src/lumo/trainer/trainer.py index f94d81a3..c465ad1c 100644 --- a/src/lumo/trainer/trainer.py +++ b/src/lumo/trainer/trainer.py @@ -8,29 +8,32 @@ import numpy as np import torch -from lumo.contrib.accelerate import Accelerator -from lumo.contrib.accelerate.utils import send_to_device -# overwrite send_to_device to resolve https://github.com/pytorch/pytorch/issues/83015 -# from accelerate import Accelerator -# from accelerate.utils import send_to_device - from accelerate import 
DistributedDataParallelKwargs +from accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader -from lumo.proc import glob -from lumo.core import ParamsType, TrainStage, Record, MetricType, Meter, Attr + +from lumo.contrib.accelerate import Accelerator +from lumo.contrib.accelerate.utils import send_to_device +from lumo.core import TrainStage, Record, MetricType, Meter from lumo.core.disk import TableRow, Metrics from lumo.data import DataModule -from ..contrib.accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard from lumo.data.loader import DataLoaderType, DataLoaderSide from lumo.proc import dist +from lumo.proc import glob from lumo.trainer.rnd import RndManager from lumo.utils.logger import Logger +from lumo.utils.fmt import strftime from .base import _BaseTrainer -from .components import TrainerExperiment +from .components import TrainerExperiment, TrainerParams from .saver import Saver +# overwrite send_to_device to resolve https://github.com/pytorch/pytorch/issues/83015 +# from accelerate import Accelerator +# from accelerate.utils import send_to_device +ParamsType = TrainerParams + class Trainer(_BaseTrainer): """ @@ -465,18 +468,28 @@ def train(self, dm: Union[DataModule, DataLoaderType] = None, params: ParamsType params = self.params for eidx in range(params.epoch): + # update training progress + self.exp.dump_train_eidx(eidx, params.epoch) self.set_epoch_idx(eidx) + + # train loop epoch_record = self.train_epoch(loader, params, limit_global_steps=limit_global_steps) - self.set_property('record', epoch_record) - self.set_property('record', epoch_record) + + # self.set_property('record', epoch_record) + + # early stop `train_toggle` if self.train_toggle: self.set_property('early_stop', 'train toggle') self.train_toggle = False break + + # early stop by `global_steps` if limit_global_steps is not None and self.global_steps >= 
limit_global_steps: self.set_property('early_stop', f'meet limit_global_steps {limit_global_steps}') break + # update when train finished + self.exp.end() self.database.update_dict(dict(end=datetime.now(), finished=True), flush=True) self.database.flush() return self._prop @@ -711,7 +724,7 @@ def wait_for_everyone(self): """ self.accelerate.wait_for_everyone() - def save_model(self, is_best=False, meta_info: Union[str, dict, Attr] = None): + def save_model(self, is_best=False, meta_info: Union[str, dict] = None): info = self._build_trainer_meta_info(meta_info) val = self.saver.save_model(self.eidx, self.model_state_dict(), meta_info=info, @@ -719,7 +732,7 @@ def save_model(self, is_best=False, meta_info: Union[str, dict, Attr] = None): self.wait_for_everyone() return val - def _build_trainer_meta_info(self, meta_info: Union[str, dict, Attr] = None): + def _build_trainer_meta_info(self, meta_info: Union[str, dict] = None): info = dict() info['eidx'] = self.eidx if meta_info is not None: diff --git a/src/lumo/utils/fmt.py b/src/lumo/utils/fmt.py index 8d6b8da5..480c812f 100644 --- a/src/lumo/utils/fmt.py +++ b/src/lumo/utils/fmt.py @@ -41,6 +41,10 @@ def strftime(fmt='%y-%m-%d-%H%M%S', dateobj: datetime = None): return datetime.now().strftime(fmt) +def strptime(fmt='%y-%m-%d-%H%M%S', datestr: str = None): + return datetime.strptime(datestr, fmt) + + _invalid_fc = ( r"[+?@#$&%*()=;|,<>: +" r"\^\-\/\t\b\[\]\"]+" diff --git a/src/lumo/utils/random.py b/src/lumo/utils/random.py index 3593d437..96accece 100644 --- a/src/lumo/utils/random.py +++ b/src/lumo/utils/random.py @@ -7,10 +7,16 @@ import numpy as np import torch +import time + + +def int_time(): + return int(str(time.time()).split(".")[-1]) def hashseed(hashitem: Union[int, str]): - assert isinstance(hashitem, (int, str)) + if not isinstance(hashitem, (int, str)): + raise AssertionError() if isinstance(hashitem, str): digest = hashlib.md5(hashitem.encode(encoding='utf-8')).digest() diff --git 
a/src/lumo/utils/repository.py b/src/lumo/utils/repository.py index f1199466..d65378f8 100644 --- a/src/lumo/utils/repository.py +++ b/src/lumo/utils/repository.py @@ -7,11 +7,14 @@ import git from git import Repo, Commit -import io from joblib import hash from .filelock2 import Lock -LUMO_BRANCH = 'lumo_experiments' + +def dev_branch(): + from lumo.proc.config import glob + return glob.get('dev_branch', 'lumo_experiments') + _commits_map = {} @@ -32,6 +35,8 @@ def __init__(self, repo: Repo, branch: str): self.branch = branch def __enter__(self): + if self.branch == self.old_branch.name: + return if self.branch is None: return @@ -45,6 +50,9 @@ def __enter__(self): return head def __exit__(self, exc_type, exc_val, exc_tb): + if self.branch == self.old_branch.name: + return + if self.branch is None: return self.repo.head.set_reference(self.old_branch) @@ -57,31 +65,24 @@ def check_have_commit(repo): repo.index.commit('initial commit') -@lru_cache() -def load_repo(dir='./'): +def load_repo(root='./'): """ Try to load git repository object of a directory. Args: - dir: str, a directory path, default is the current working dir. + root: str, a directory path, default is the current working dir. if dir is a repository dir, then a git.Repo object will be retured. if not, some you can type a path to init it, or type '!' to cancel init it. Returns: git.Repo object or None if dir not have git repository and cancel to init it. """ - path = git_dir(dir) + path = git_dir(root) repo = Repo(path) check_have_commit(repo) return repo -def add(repo=None): - if repo is None: - repo = load_repo() - return repo.git.add(all=True) - - -def git_commit(repo=None, key=None, branch_name=LUMO_BRANCH, info: str = None, filter_files=None): +def git_commit(repo=None, key=None, branch_name=None, info: str = None, filter_files=None): """ ``` cd @@ -103,6 +104,9 @@ def git_commit(repo=None, key=None, branch_name=LUMO_BRANCH, info: str = None, f Returns: git.Commit object, see gitpython for details. 
""" + if branch_name is None: + branch_name = dev_branch() + try: if repo is None: repo = load_repo() @@ -110,23 +114,26 @@ def git_commit(repo=None, key=None, branch_name=LUMO_BRANCH, info: str = None, f if key is not None and key in _commits_map: return _commits_map[key] - if LUMO_BRANCH not in repo.branches: - repo.create_head(LUMO_BRANCH) - print(f'branch {LUMO_BRANCH} not found, will be created automatically.') + if branch_name not in repo.branches: + repo.create_head(branch_name) + print(f'branch {branch_name} not found, will be created automatically.') - exp_head_commit = repo.heads[LUMO_BRANCH].commit - diff = repo.active_branch.commit.diff(exp_head_commit) + diff_uncommit = repo.head.commit.diff() + exp_head_commit = repo.heads[branch_name].commit + diff_from_branches = repo.active_branch.commit.diff(exp_head_commit) + # print(diff_uncommit) if filter_files is not None: - diff = [i.a_path for i in diff if i.a_path in filter_files] + diff_from_branches = [i.a_path for i in diff_from_branches if i.a_path in filter_files] - if len(diff) == 0: + if len(diff_from_branches) == 0 and len(diff_uncommit) == 0 and len(repo.untracked_files) == 0: commit_ = exp_head_commit else: with branch(repo, branch_name): change_file = [] change_file.extend(repo.untracked_files) - change_file.extend([i.a_path for i in repo.head.commit.diff(None)]) + change_file.extend([i.a_path for i in diff_from_branches]) + change_file.extend([i.a_path for i in diff_uncommit]) # print(change_file) if filter_files is not None: print('before filter', change_file) @@ -147,22 +154,7 @@ def git_commit(repo=None, key=None, branch_name=LUMO_BRANCH, info: str = None, f return commit_ -def reset(repo=None, commit_hex=None, commit: Commit = None): - """ - 将工作目录中的文件恢复到某个commit - 恢复快照的 git 流程: - git branch experiment - git add . & git commit -m ... 
// 保证文件最新,防止冲突报错,这一步由 Experiment() 代为完成 - git checkout // 恢复文件到 - git checkout -b reset // 将当前状态附到新的临时分支 reset 上 - git branch experiment // 切换回 experiment 分支 - git add . & git commit -m ... // 将当前状态重新提交到最新 - // 此时experiment 中最新的commit 为恢复的 - git branch -D reset // 删除临时分支 - git branch master // 最终回到原来分支,保证除文件变动外git状态完好 - Returns: - An Experiment represents this reset operation - """ +def git_checkout(repo=None, commit_hex=None, commit: Commit = None): if repo is None: repo = load_repo() @@ -172,44 +164,41 @@ def reset(repo=None, commit_hex=None, commit: Commit = None): old_path = os.getcwd() os.chdir(commit.tree.abspath) - with branch(commit.repo, LUMO_BRANCH) as new_branch: - repo.git.checkout(commit.hexsha) - repo.git.checkout('-b', 'reset') - repo.head.reference = new_branch - _ = git_commit(repo, branch_name=repo.head.reference.name, info="Reset from {}".format(commit.hexsha)) - repo.git.branch('-d', 'reset') + # with branch(commit.repo, LUMO_BRANCH) as new_branch: + repo.git.checkout('-b', commit.hexsha[:8], commit.hexsha) os.chdir(old_path) - return None + return commit.hexsha[:8] -def archive(repo=None, commit_hex=None, commit: Commit = None, tgt=None): +def git_archive(repo=None, commit_hex=None, commit: Commit = None): """ - TODO - 将某次 test 对应 commit 的文件打包,相关命令为 - git archive -o + git archive -o + Returns: An Experiment represents this archive operation """ + from lumo.exp import Experiment if repo is None: repo = load_repo() + if commit is None and commit_hex is not None: + commit = repo.commit(commit_hex) + old_path = os.getcwd() os.chdir(commit.tree.abspath) - # exp = Experiment('Archive') - - # revert_path = checkpath(cache_dir(), 'archives', commit) - # revert_fn = os.path.join(revert_path, "code.zip") + exp = Experiment('GitArchive') + fn = exp.blob_file(f'{commit.hexsha[:8]}.tar') - # TODO 在code.zip目录下添加相关说明 - # exp.add_plugin('archive', {'file': revert_fn, - # 'test_name': self.name}) - # with open(revert_fn, 'wb') as w: - # repo.archive(w, commit) 
+ exp.dump_info('git_archive', {'file': fn, + 'test_name': exp.test_name, + 'commit_hex': commit.hexsha[:8]}) + exp.dump_string('archive_fn', fn) + with open(fn, 'wb') as w: + repo.archive(w, commit.hexsha) - # exp.end() os.chdir(old_path) - return None + return exp @lru_cache(1) @@ -248,19 +237,9 @@ def git_dir(root='./'): else: return None - -def get_tree_from_commit(commit: Commit, tree=None): - if tree is None: - tree = commit.tree - yield tree.abspath, tree.blobs, tree.trees - for tree in tree.trees: - yield from get_tree_from_commit(commit, tree) - - -def get_diff_tree_from_commits(): - pass - - -def get_file_of_commit(commit: Commit, file_name) -> bytes: - blob = commit.tree / file_name - return blob.data_stream.read() +# def get_tree_from_commit(commit: Commit, tree=None): +# if tree is None: +# tree = commit.tree +# yield tree.abspath, tree.blobs, tree.trees +# for tree in tree.trees: +# yield from get_tree_from_commit(commit, tree) diff --git a/src/lumo/utils/screen.py b/src/lumo/utils/screen.py index b2c346a6..fb41dc4b 100644 --- a/src/lumo/utils/screen.py +++ b/src/lumo/utils/screen.py @@ -25,6 +25,7 @@ def _is_jupyter() -> bool: # pragma: no cover try: get_ipython # type: ignore except NameError: + get_ipython = lambda: () return False ipython = get_ipython() # type: ignore shell = ipython.__class__.__name__ diff --git a/tests/core/test_attr.py b/tests/core/test_attr.py index 51f3aaaf..0d250769 100644 --- a/tests/core/test_attr.py +++ b/tests/core/test_attr.py @@ -1,4 +1,4 @@ -from lumo.core.attr import Attr as attr +from lumo.core.attr import Attr as attr, set_item_iterative, get_item_iterative import numpy as np import torch @@ -27,6 +27,20 @@ def get_res(): def test_replace(): res = get_res() - res.update(a=6, b=7) + res.update(a=6, b=[4, 5]) + res['c.c.e.f'] = 5 assert res.a == 6 - assert res.b == 7 + assert res.b == [4, 5] + assert res['c.c.e.f'] == 5 + assert res['c.a'] == 1 + assert res['c.b'] == [5, 6, 7] + assert isinstance(res['c.c.e'], dict) 
+ + +def test_get_set(): + res = {} + set_item_iterative(res, ['a', 'b', 'c'], 4) + assert isinstance(res['a'], dict) + assert isinstance(res['a']['b'], dict) + assert res['a']['b']['c'] == 4 + # set_item_iterative(res, '') diff --git a/tests/core/test_meta_classes.py b/tests/core/test_meta_classes.py deleted file mode 100644 index 958b32f2..00000000 --- a/tests/core/test_meta_classes.py +++ /dev/null @@ -1,17 +0,0 @@ -""" - -""" - -from lumo.core.metaclasses import Merge - - -def test_merge(): - class A(metaclass=Merge): - _item = {1: 2, 3: 4} - - class B(A): - _item = {5: 6, 7: 8} - - b = B() - assert 1 in b._item and 3 in b._item and 5 in b._item and 7 in b._item - assert 1 in B._item and 3 in B._item and 5 in B._item and 7 in B._item diff --git a/tests/core/test_params.py b/tests/core/test_params.py index 407005a4..5510a78f 100644 --- a/tests/core/test_params.py +++ b/tests/core/test_params.py @@ -1,3 +1,7 @@ +import json +import tempfile +from omegaconf import DictConfig + from lumo.core.raises import BoundCheckError from lumo import BaseParams @@ -42,6 +46,15 @@ def get_res(): return res +def test_argv(): + params = get_res() + params.from_args(['--a', '1', '--d.c.d=2']) + assert params.a == 1 + assert params.d.c.d == 2 + assert isinstance(params.kk, DictConfig) + assert isinstance(params.d.c, DictConfig) + + def test_dict(): res = get_res() jsn = res.to_dict() @@ -49,6 +62,15 @@ def test_dict(): assert rres.hash() == res.hash() +def test_json(): + res = get_res() + fn = tempfile.mktemp() + with open(fn, 'w') as w: + json.dump({'c': {'a': 2}}, w) + res.from_json(fn) + assert res.c.a == 2 + + def test_copy(): res = get_res() copy = res.copy() diff --git a/src/lumo/cli/functional/tune.py b/tests/data/__init__.py similarity index 100% rename from src/lumo/cli/functional/tune.py rename to tests/data/__init__.py diff --git a/src/lumo/cli/functional/summary.py b/tests/data/test_side.py similarity index 66% rename from src/lumo/cli/functional/summary.py rename to 
tests/data/test_side.py index b28b04f6..139597f9 100644 --- a/src/lumo/cli/functional/summary.py +++ b/tests/data/test_side.py @@ -1,3 +1,2 @@ - diff --git a/src/lumo/cli/functional/watch.py b/tests/exp/__init__.py similarity index 100% rename from src/lumo/cli/functional/watch.py rename to tests/exp/__init__.py diff --git a/tests/trainer/test_builder.py b/tests/trainer/test_builder.py index 94327d3b..5e01ff16 100644 --- a/tests/trainer/test_builder.py +++ b/tests/trainer/test_builder.py @@ -1,4 +1,4 @@ -from lumo import DatasetBuilder +from lumo import DatasetBuilder, DataLoaderSide def global_check(dic): @@ -15,8 +15,8 @@ def create_dataset_builder(): .add_output(name='xs', outkey='xs1') .add_output(name='xs', outkey='xs2') .add_output(name='ys', outkey='ys1') - .add_output_transform('xs1', lambda x: x + 1) - .add_output_transform('ys1', lambda x: x - 1) + .set_output_transform('xs1', lambda x: x + 1) + .set_output_transform('ys1', lambda x: x - 1) .add_global_transform(global_check) ) return builder @@ -56,3 +56,28 @@ def test_builder_base(): assert 'ys1' in dic str(sub_builder) + + +def test_side(): + sup = create_dataset_builder() + un = create_dataset_builder() + + dl = ( + DataLoaderSide() + .add('sup', sup.DataLoader(batch_size=128, drop_last=True), cycle=True) + .add('un', un.DataLoader(batch_size=32, drop_last=True)) + .zip() + ) + + assert len(dl) == len(un) // 32 + + for batch in dl: + assert isinstance(batch, dict) + sup, un = batch['sup'], batch['un'] + assert sup + + assert sup['xs1'].shape[0] == 128 + assert 'xs1' in sup + assert 'xs2' in sup + assert 'ys1' in sup + assert un['xs1'].shape[0] == 32 diff --git a/tests/trainer/test_finder.py b/tests/trainer/test_finder.py new file mode 100644 index 00000000..2588bae8 --- /dev/null +++ b/tests/trainer/test_finder.py @@ -0,0 +1,41 @@ +import random + +from lumo import Trainer, ParamsType, TrainerParams, Experiment +from lumo.exp import finder +from lumo.proc.config import debug_mode + + +class 
ATrainer(Trainer): + + def icallbacks(self, params: ParamsType): + super().icallbacks(params) + + +class BTrainer(Trainer): + + def icallbacks(self, params: ParamsType): + super().icallbacks(params) + + +def test_finder(): + debug_mode() + + for i in range(5): + params = TrainerParams() + params.epoch = i + params.rnd = random.random() + ATrainer(params).train() + BTrainer(params).train() + + all_tests = finder.list_all() + assert len(all_tests) == 2 + assert ATrainer.generate_exp_name() in all_tests + assert BTrainer.generate_exp_name() in all_tests + assert len(all_tests[ATrainer.generate_exp_name()]) == 5 + assert len(all_tests[BTrainer.generate_exp_name()]) == 5 + + assert isinstance(all_tests[ATrainer.generate_exp_name()][0], Experiment) + for exp in all_tests[ATrainer.generate_exp_name()]: + params = TrainerParams().from_yaml(exp.properties['params.yaml']) + assert params.hash() == exp.properties['params_hash'] + assert finder.find_path_from_test_name(exp.test_name) == exp.test_root diff --git a/tests/trainer/test_skip.py b/tests/trainer/test_skip.py index 7176bbb1..966af2ac 100644 --- a/tests/trainer/test_skip.py +++ b/tests/trainer/test_skip.py @@ -3,6 +3,8 @@ import tempfile import torch + +from lumo.proc.config import debug_mode from lumo.proc.path import cache_dir from torch import nn from torch.utils.data import DataLoader @@ -22,8 +24,8 @@ def create_dataset_builder(): .add_output(name='xs', outkey='xs1') .add_output(name='xs', outkey='xs2') .add_output(name='ys', outkey='ys1') - .add_output_transform('xs1', lambda x: x + 1) - .add_output_transform('ys1', lambda x: x - 1) + .set_output_transform('xs1', lambda x: x + 1) + .set_output_transform('ys1', lambda x: x - 1) ) return builder @@ -121,16 +123,7 @@ def test_callback(): params = MyParams() params.epoch = 2 - glob['exp_root'] = tempfile.mkdtemp(dir=cache_dir()) - glob['blob_root'] = tempfile.mkdtemp(dir=cache_dir()) - glob['metric_root'] = tempfile.mkdtemp(dir=cache_dir()) - glob['HOOK_LOCKFILE'] = 
False - glob['HOOK_LASTCMD_DIR'] = tempfile.mkdtemp(dir=cache_dir()) - glob['HOOK_GITCOMMIT'] = False - glob['HOOK_RECORDABORT'] = False - glob['HOOK_DIARY'] = False - glob['HOOK_TIMEMONITOR'] = False - # glob['HOOK_FINALREPORT'] = False + debug_mode() trainer = CBTrainer(params, dm=MyDataModule()) trainer.train() trainer.test() diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 34cf0927..35415bbb 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1,5 +1,9 @@ from typing import Union, Optional, Sequence, Mapping, Any +from lumo.proc.config import debug_mode +from lumo.utils.repository import git_dir +import os + import tempfile import torch @@ -21,8 +25,8 @@ def create_dataset_builder(): .add_output(name='xs', outkey='xs1') .add_output(name='xs', outkey='xs2') .add_output(name='ys', outkey='ys1') - .add_output_transform('xs1', lambda x: x + 1) - .add_output_transform('ys1', lambda x: x - 1) + .set_output_transform('xs1', lambda x: x + 1) + .set_output_transform('ys1', lambda x: x - 1) ) return builder @@ -132,19 +136,11 @@ def idataloader(self, params: ParamsType = None, stage: TrainStage = None): self.regist_dataloader_with_stage(stage, dl) -def test_callback(): +def test_trainer(): params = TrainerParams() params.epoch = 2 - glob['exp_root'] = tempfile.mkdtemp() - glob['blob_root'] = tempfile.mkdtemp() - glob['metric_root'] = tempfile.mkdtemp() - glob['HOOK_LOCKFILE'] = False - glob['HOOK_LASTCMD_DIR'] = tempfile.mkdtemp() - glob['HOOK_GITCOMMIT'] = False - glob['HOOK_RECORDABORT'] = False - glob['HOOK_DIARY'] = False - glob['HOOK_TIMEMONITOR'] = False + debug_mode() # glob['HOOK_FINALREPORT'] = False trainer = CBTrainer(params, dm=MyDataModule()) trainer.train() @@ -153,6 +149,12 @@ def test_callback(): trainer.logger.info(trainer.lf.functions) trainer.exp.end() + # test trainer experiment + exp = trainer.exp + assert exp.exp_root == os.path.join(glob['exp_root'], trainer.generate_exp_name()) + assert 
exp.lib_root == glob['home'] + assert exp.blob_root == os.path.join(glob['blob_root'], trainer.generate_exp_name(), exp.test_name) + assert exp.project_root == git_dir() # how to test writer? _ = trainer.safe_writer @@ -161,4 +163,4 @@ def test_callback(): if __name__ == '__main__': - test_callback() + test_trainer() diff --git a/tests/utils/test_random.py b/tests/utils/test_random.py index 279f460e..f1493c68 100644 --- a/tests/utils/test_random.py +++ b/tests/utils/test_random.py @@ -13,14 +13,14 @@ def test_device(): d = torch.rand(10, device='cuda') assert (a == d).all() - if torch.has_mps: - # [2023.02.22] Currently (as MPS support is quite new) there is no way to set the seed for MPS directly. - # fix_seed(1) - # a = torch.rand(10, device='mps') - # fix_seed(1) - # d = torch.rand(10, device='mps') - # assert (a == d).all() - pass + # if torch.has_mps: + # [2023.02.22] Currently (as MPS support is quite new) there is no way to set the seed for MPS directly. + # fix_seed(1) + # a = torch.rand(10, device='mps') + # fix_seed(1) + # d = torch.rand(10, device='mps') + # assert (a == d).all() + pass def test_state(): diff --git a/tests/utils/test_repository.py b/tests/utils/test_repository.py new file mode 100644 index 00000000..6956bda0 --- /dev/null +++ b/tests/utils/test_repository.py @@ -0,0 +1,71 @@ +import tempfile +import os +import git + +from lumo.proc.config import debug_mode +from lumo.utils import repository +import random + + +def write(fn): + with open(fn, 'w') as w: + st = str(random.random()) + w.write(st) + return st + + +def read(fn): + with open(fn) as w: + return w.read() + + +def test_git(): + debug_mode() + git_root = tempfile.mkdtemp() + old_root = os.getcwd() + os.chdir(git_root) + repo = git.Repo.init(git_root) + f_str = write('init.txt') + repo.index.add(['init.txt']) + repo.index.commit('initial commit') + main_branch = repo.active_branch.name + + repository.git_commit(repo) + # untracked_files + a_str = write('a.txt') + a_hash = 
repository.git_commit(repo) + # uncommitted changes + b_str = write('a.txt') + # untracked_files + bb_str = write('b.txt') + b_hash = repository.git_commit(repo) + + c_str = write('a.txt') + # committed changes + c_hash = repository.git_commit(repo, branch_name=main_branch) + d_hash = repository.git_commit(repo, branch_name=main_branch) + cc_hash = repository.git_commit(repo) + assert c_hash == d_hash + old_branch_name = repository.git_checkout(repo, a_hash) + + assert repo.active_branch.name == old_branch_name + assert read('a.txt') == a_str + old_branch_name = repository.git_checkout(repo, b_hash) + assert repo.active_branch.name == old_branch_name + assert read('a.txt') == b_str + assert read('b.txt') == bb_str + + old_branch_name = repository.git_checkout(repo, cc_hash) + assert repo.active_branch.name == old_branch_name + assert read('a.txt') == c_str + assert read('b.txt') == bb_str + + import tarfile + + exp = repository.git_archive(repo, b_hash) + archived_fn = exp.load_string('archive_fn') + file = tarfile.open(archived_fn, mode='r') + assert file.extractfile('a.txt').read().decode() == b_str + assert file.extractfile('init.txt').read().decode() == f_str + + os.chdir(old_root)