Showing 18 changed files with 1,530 additions and 2 deletions.
@@ -0,0 +1 @@

```
*pyc
```
@@ -1,4 +1,95 @@

This update retitles the README from "Histopathology" to "Pathology" and replaces the old "Model code, weights and paper coming soon..." placeholder with the content below.

# Hibou: A Family of Foundational Vision Transformers for Pathology

[https://arxiv.org/abs/2406.05074](https://arxiv.org/abs/2406.05074)

## Introduction

This repository contains the code to run the Hibou-B model locally. For inquiries about accessing Hibou-L, please contact us at [[email protected]](mailto:[email protected]).

## Getting Started

### Using HuggingFace

The easiest way to use the Hibou-B model is through the HuggingFace repository. Run the following code to get started:

```python
from transformers import AutoImageProcessor, AutoModel

processor = AutoImageProcessor.from_pretrained("histai/hibou-b", trust_remote_code=True)
model = AutoModel.from_pretrained("histai/hibou-b", trust_remote_code=True)
```

We use a customized implementation of the DINOv2 architecture from the transformers library to add support for registers, which requires the `trust_remote_code=True` flag.
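As a quick sanity check, here is a minimal inference sketch with the HuggingFace model. It assumes an image at `images/sample.png` (the sample path used in the example notebook); substitute any RGB histology patch.

```python
import torch
from PIL import Image

# Load an RGB patch; "images/sample.png" mirrors the example notebook.
image = Image.open("images/sample.png").convert("RGB")

# The processor handles resizing and normalization for the model.
inputs = processor(images=image, return_tensors="pt")

model.eval()
with torch.no_grad():
    outputs = model(**inputs)

# pooler_output is the image-level embedding: one feature vector per image.
print(outputs.pooler_output.shape)
```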
### Using the Model Directly

If you prefer to use the model without the transformers library, follow these steps (a full feature-extraction sketch follows the list):

1. **Install the requirements and the package:**

    ```bash
    git clone https://github.com/HistAI/hibou.git
    cd hibou
    pip install -r requirements.txt && pip install -e .
    ```

2. **Download the model weights:**

    [Hibou-B Weights](https://drive.google.com/file/d/12ICd_-yJWMYYo5OskMmc9SHJAQivAtS7/view?usp=sharing)

3. **Load the model with the following code:**

    ```python
    from hibou import build_model

    model = build_model("weights-path")
    ```
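Putting the steps together, a minimal end-to-end sketch looks like the following. The preprocessing (224×224 bicubic resize, center crop, and normalization constants) is copied from the example notebook; `hibou-b.pth` and `images/sample.png` are placeholder paths for the downloaded weights and your input image.

```python
import torch
import torchvision.transforms as T
from PIL import Image

from hibou import build_model

# Preprocessing from the example notebook: bicubic resize, center crop,
# and normalization with the Hibou training statistics.
transforms = T.Compose([
    T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.7068, 0.5755, 0.7220], std=[0.1950, 0.2316, 0.1816]),
])

model = build_model(weights_path="hibou-b.pth")  # placeholder weights path
model.eval()

image = Image.open("images/sample.png").convert("RGB")  # placeholder image
data = transforms(image).unsqueeze(0)  # add a batch dimension

with torch.no_grad():
    features = model(data)

print(features.shape)  # one feature vector per image
```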
### Example Notebook

For more information, refer to the [example.ipynb](example.ipynb) notebook.
## Metrics

**Table: Linear probing benchmarks reporting top-1 accuracy.**

*Metrics for Virchow and RudolfV are derived from the respective papers, as these models are not open-sourced.*

| Dataset | Phikon | Kaiko-B8 | Virchow* | RudolfV* | Prov-GigaPath | Hibou-B | Hibou-L |
|-----------|--------|----------|----------|----------|---------------|---------|---------|
| CRC-100K | 0.917 | 0.949 | 0.968* | **0.973*** | 0.968 | 0.955 | 0.966 |
| PCAM | 0.916 | 0.919 | 0.933* | 0.944* | **0.947** | 0.946 | 0.943 |
| MHIST | 0.791 | 0.832 | 0.834* | 0.821* | 0.839 | 0.812 | **0.849** |
| MSI-CRC | 0.750 | 0.786 | - | 0.755* | 0.771 | 0.779 | **0.797** |
| MSI-STAD | 0.760 | 0.814 | - | 0.788* | 0.784 | 0.797 | **0.825** |
| TIL-DET | 0.944 | **0.945** | - | 0.943* | 0.939 | 0.942 | 0.943 |
| **AVG (1-3)** | 0.875 | 0.900 | 0.912 | 0.913 | 0.918 | 0.904 | **0.919** |
| **AVG (1-6)** | 0.846 | 0.874 | - | 0.871 | 0.875 | 0.872 | **0.887** |
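Linear probing typically means training a single linear classifier on features from the frozen backbone. The sketch below illustrates the idea with scikit-learn on stand-in data; it is not the evaluation protocol behind the numbers above, and the feature arrays are synthetic placeholders.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Stand-in for embeddings extracted with the frozen Hibou-B backbone
# (768-dimensional for a ViT-B model); replace with real features/labels.
rng = np.random.default_rng(0)
X_train, y_train = rng.normal(size=(1000, 768)), rng.integers(0, 2, size=1000)
X_test, y_test = rng.normal(size=(200, 768)), rng.integers(0, 2, size=200)

# The "probe" is just a linear classifier; the backbone is never updated.
clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)
print("top-1 accuracy:", accuracy_score(y_test, clf.predict(X_test)))
```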
## License

This repository is licensed under the Apache License, Version 2.0. See the [LICENSE](LICENSE) file for the full license text.

## Acknowledgements

We would like to thank the authors of the DINOv2 repository, upon which this project is built. The original repository can be found [here](https://github.com/facebookresearch/dinov2).

---

Feel free to reach out at [[email protected]](mailto:[email protected]) if you have any questions or need further assistance!

## Citation

If you use our work, please cite:

```
@misc{nechaev2024hibou,
      title={Hibou: A Family of Foundational Vision Transformers for Pathology},
      author={Dmitry Nechaev and Alexey Pchelnikov and Ekaterina Ivanova},
      year={2024},
      eprint={2406.05074},
      archivePrefix={arXiv},
      primaryClass={eess.IV}
}
```
@@ -0,0 +1,207 @@

```json
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ec86b098",
   "metadata": {},
   "source": [
    "# Hibou Model Usage Example\n",
    "\n",
    "This notebook showcases the basic usage of the Hibou model. The minimal installation for this notebook to work is:\n",
    "```bash\n",
    "pip install torch torchvision opencv-python matplotlib\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, torchvision\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load the test image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image = Image.open(\"images/sample.png\").convert(\"RGB\")\n",
    "plt.imshow(image)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HuggingFace Hub Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoImageProcessor, AutoModel\n",
    "\n",
    "processor = AutoImageProcessor.from_pretrained(\"histai/hibou-b\", trust_remote_code=True)\n",
    "hf_model = AutoModel.from_pretrained(\"histai/hibou-b\", trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_data = processor(images=image, return_tensors=\"pt\").to(device)\n",
    "hf_model = hf_model.to(device)\n",
    "hf_model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    hf_output = hf_model(**hf_data)\n",
    "\n",
    "print(hf_output.pooler_output.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Local Example\n",
    "\n",
    "Download the model weights from [Google Drive](https://drive.google.com/file/d/12ICd_-yJWMYYo5OskMmc9SHJAQivAtS7/view?usp=sharing) and put them in the root of the hibou directory.\n",
    "\n",
    "The cell below should work without installing anything, but if you'd like to use the model from anywhere, `cd` to the hibou directory and run:\n",
    "```bash\n",
    "pip install -r requirements.txt && pip install -e .\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hibou import build_model\n",
    "\n",
    "model = build_model(weights_path=\"hibou-b.pth\")\n",
    "\n",
    "print(\"Total parameters:\", sum(p.numel() for p in model.parameters()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Get the features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transforms = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.Resize((224, 224), interpolation=torchvision.transforms.InterpolationMode.BICUBIC),\n",
    "    torchvision.transforms.CenterCrop((224, 224)),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize(mean=[0.7068, 0.5755, 0.7220], std=[0.1950, 0.2316, 0.1816]),\n",
    "])\n",
    "\n",
    "data = transforms(image).unsqueeze(0).to(device)\n",
    "model = model.to(device)\n",
    "model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    output = model(data)\n",
    "\n",
    "print(output.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Or, say you're building a segmentation model and want intermediate features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    extended_output = model.forward_features(data, return_intermediate=True)\n",
    "\n",
    "print(extended_output.keys())\n",
    "print(f\"Total intermediate outputs: {len(extended_output['intermediate'])}\", f\"\\nThe shape of the intermediate output: {extended_output['intermediate'][-1].shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### If you've run both the HuggingFace Hub and the local examples, the cell below checks that the two outputs match. Any difference should come down to rounding errors and be very small."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print((output - hf_output.pooler_output).abs().max())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ML",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
```
@@ -0,0 +1 @@

```python
from .models import build_model
```
@@ -0,0 +1,43 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
#
# Portions Copyright (c) HistAI Inc.

import torch

from . import vision_transformer


def build_model(
    weights_path=None,
    img_size=224,
    arch="vit_base",
    patch_size=14,
    layerscale=1e-5,
    ffn_layer="swiglufused",
    block_chunks=0,
    qkv_bias=True,
    proj_bias=True,
    ffn_bias=True,
    num_register_tokens=4,
    interpolate_offset=0,
    interpolate_antialias=True,
):
    """Build a Hibou vision transformer. The defaults match Hibou-B:
    a ViT-B/14 with 4 register tokens and a fused-SwiGLU FFN."""
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=layerscale,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        qkv_bias=qkv_bias,
        proj_bias=proj_bias,
        ffn_bias=ffn_bias,
        num_register_tokens=num_register_tokens,
        interpolate_offset=interpolate_offset,
        interpolate_antialias=interpolate_antialias,
    )
    # Look up the architecture constructor (e.g., vit_base) by name.
    model = vision_transformer.__dict__[arch](**vit_kwargs)
    if weights_path is not None:
        # strict=False tolerates missing/unexpected keys; the load result
        # is printed so any mismatches are visible.
        print(model.load_state_dict(torch.load(weights_path), strict=False))
    return model
```
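For reference, a short usage sketch of `build_model`; the weights filename is a placeholder for wherever you saved the download:

```python
from hibou import build_model

# Randomly initialized Hibou-B-shaped backbone (weights_path defaults to None).
model = build_model()

# Load released weights; "hibou-b.pth" is a placeholder path.
model = build_model(weights_path="hibou-b.pth")
```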
@@ -0,0 +1,11 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .dino_head import DINOHead
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
```