
A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.

RyotaroOKabe/Ensemble-Pytorch


Ensemble PyTorch

A unified ensemble framework for PyTorch to easily improve the performance and robustness of your deep learning model. Ensemble-PyTorch is part of the PyTorch ecosystem, which requires its projects to be well maintained.

Installation

pip install torchensemble

Example

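import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchensemble import VotingClassifier  # voting is a classic ensemble strategy

# Hyperparameters (example values; tune them for your task)
learning_rate = 1e-3
weight_decay = 5e-4
epochs = 50


# Base estimator: any torch.nn.Module subclass. This small MLP for 28x28
# grayscale images with 10 classes is only an illustrative example.
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.net(x)


# Load data (replace ... with your own dataset and loader arguments)
train_loader = DataLoader(...)
test_loader = DataLoader(...)

# Define the ensemble
ensemble = VotingClassifier(
    estimator=MLP,                          # estimator is your PyTorch model class
    n_estimators=10,                        # number of base estimators
    cuda=torch.cuda.is_available(),         # use the GPU only if one is available
)

# Set the optimizer
ensemble.set_optimizer(
    "Adam",                                 # type of parameter optimizer
    lr=learning_rate,                       # learning rate of parameter optimizer
    weight_decay=weight_decay,              # weight decay of parameter optimizer
)

# Set the learning rate scheduler
ensemble.set_scheduler(
    "CosineAnnealingLR",                    # type of learning rate scheduler
    T_max=epochs,                           # extra keyword arguments passed to the scheduler
)

# Train the ensemble
ensemble.fit(
    train_loader,
    epochs=epochs,                          # number of training epochs
)

# Evaluate the ensemble
acc = ensemble.evaluate(test_loader)         # testing accuracy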
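A voting ensemble has to turn the outputs of its base estimators into a single prediction. The sketch below, in plain Python with no PyTorch dependency, shows two common rules: soft voting (average each estimator's softmax probabilities, then take the argmax) and hard voting (majority of each estimator's predicted class). It illustrates the idea only and is not torchensemble's implementation; the function names and the logits are hypothetical.

```python
import math
from collections import Counter
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn one estimator's raw scores (logits) into class probabilities."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def soft_vote(all_logits: Sequence[Sequence[float]]) -> List[float]:
    """Average the class probabilities of all base estimators for one sample."""
    probas = [softmax(logits) for logits in all_logits]
    n_classes = len(probas[0])
    return [sum(p[c] for p in probas) / len(probas) for c in range(n_classes)]


def hard_vote(all_logits: Sequence[Sequence[float]]) -> int:
    """Majority vote over each base estimator's predicted class."""
    votes = Counter(argmax(logits) for logits in all_logits)
    return votes.most_common(1)[0][0]


# Hypothetical logits from three base estimators for one sample, three classes
logits = [
    [2.0, 1.0, 0.1],
    [0.5, 2.5, 0.3],
    [1.8, 1.2, 0.2],
]
print(hard_vote(logits), argmax(soft_vote(logits)))  # prints "0 1"
```

Here two of the three estimators prefer class 0, but the second one is confident enough about class 1 that the averaged probabilities (about 0.446 for class 0 vs 0.453 for class 1) tip toward class 1. Soft voting keeps this confidence information, whereas hard voting discards it.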

Supported Ensemble

Ensemble Name                Type        Source Code                Problem
Fusion                       Mixed       fusion.py                  Classification / Regression
Voting [1]                   Parallel    voting.py                  Classification / Regression
Neural Forest                Parallel    voting.py                  Classification / Regression
Bagging [2]                  Parallel    bagging.py                 Classification / Regression
Gradient Boosting [3]        Sequential  gradient_boosting.py       Classification / Regression
Snapshot Ensemble [4]        Sequential  snapshot_ensemble.py       Classification / Regression
Adversarial Training [5]     Parallel    adversarial_training.py    Classification / Regression
Fast Geometric Ensemble [6]  Sequential  fast_geometric.py          Classification / Regression
Soft Gradient Boosting [7]   Parallel    soft_gradient_boosting.py  Classification / Regression
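Of the sequential methods in the table, Snapshot Ensemble [4] is the easiest to sketch: instead of training M separate networks, it trains one network with a cyclic cosine learning rate and saves a snapshot at the end of each cycle, when the rate is lowest. The plain-Python sketch below computes the schedule given in [4], lr(t) = lr0 / 2 * (cos(pi * mod(t - 1, ceil(T / M)) / ceil(T / M)) + 1), where t is the 1-indexed iteration, T the total number of iterations and M the number of snapshots. The function names are hypothetical, and this is not the code in snapshot_ensemble.py, which may differ in details.

```python
import math
from typing import List


def snapshot_lr(t: int, total_iters: int, n_snapshots: int, lr0: float) -> float:
    """Cyclic cosine annealing rate at 1-indexed iteration t (schedule from [4])."""
    cycle_len = math.ceil(total_iters / n_snapshots)
    pos = (t - 1) % cycle_len  # position inside the current cycle
    return lr0 / 2 * (math.cos(math.pi * pos / cycle_len) + 1)


def snapshot_points(total_iters: int, n_snapshots: int) -> List[int]:
    """Iterations that end a cycle, i.e. where a snapshot would be saved.

    Assumes total_iters is a multiple of n_snapshots, so every cycle is full.
    """
    cycle_len = math.ceil(total_iters / n_snapshots)
    return [t for t in range(1, total_iters + 1) if t % cycle_len == 0]


# 12 iterations, 3 snapshots, initial learning rate 0.1
print([round(snapshot_lr(t, 12, 3, 0.1), 4) for t in range(1, 5)])  # prints "[0.1, 0.0854, 0.05, 0.0146]"
print(snapshot_points(12, 3))  # prints "[4, 8, 12]"
```

The rate restarts at lr0 at the start of every cycle, which pushes the network out of the minimum it just reached, so the saved snapshots tend to differ from one another.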

Dependencies

  • scikit-learn>=0.23.0
  • torch>=1.4.0
  • torchvision>=0.2.2
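The list above gives minimum versions only. One way to check an existing environment against it, sketched here assuming Python 3.8+ (for importlib.metadata), is to compare numeric release tuples rather than raw strings, since a string comparison would wrongly rank "1.10.0" below "1.4.0". The helper names are hypothetical; for full PEP 440 handling (pre-releases, epochs), packaging.version.Version is the more robust choice.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Tuple

# Minimum versions taken from the dependency list above
MINIMUM_VERSIONS = {
    "scikit-learn": "0.23.0",
    "torch": "1.4.0",
    "torchvision": "0.2.2",
}


def release_tuple(ver: str) -> Tuple[int, ...]:
    """Numeric release part of a version, e.g. '1.13.1+cu117' -> (1, 13, 1).

    Pre-release and local suffixes are ignored, so '2.0.0rc1' counts as 2.0.0.
    """
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognised version string: {ver!r}")
    parts = [int(p) for p in match.group(0).split(".")]
    # Drop trailing zeros so that '1.4' and '1.4.0' compare as equal
    while len(parts) > 1 and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """True if the installed release is at least the required minimum."""
    return release_tuple(installed) >= release_tuple(minimum)


def check_environment() -> Dict[str, bool]:
    """Map each dependency to whether an adequate version is installed."""
    report = {}
    for name, minimum in MINIMUM_VERSIONS.items():
        try:
            report[name] = meets_minimum(version(name), minimum)
        except PackageNotFoundError:
            report[name] = False  # not installed at all
    return report


if __name__ == "__main__":
    print(check_environment())
```

Run as a script, it prints one True/False entry per dependency; a package that is not installed at all is reported as False.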

Reference

[1] Zhou, Zhi-Hua. Ensemble Methods: Foundations and Algorithms. CRC Press, 2012.
[2] Breiman, Leo. Bagging Predictors. Machine Learning (1996): 123-140.
[3] Friedman, Jerome H. Greedy Function Approximation: A Gradient Boosting Machine. Annals of Statistics (2001): 1189-1232.
[4] Huang, Gao, et al. Snapshot Ensembles: Train 1, Get M For Free. ICLR, 2017.
[5] Lakshminarayanan, Balaji, et al. Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles. NIPS, 2017.
[6] Garipov, Timur, et al. Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs. NeurIPS, 2018.
[7] Feng, Ji, et al. Soft Gradient Boosting Machine. arXiv, 2020.

Thanks to all our contributors
