PaDiM

A Patch Distribution Modeling Framework for Anomaly Detection and Localization

This is an unofficial re-implementation of the paper PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization, available on arXiv.

Features

The key features of this implementation are:

  • Constant memory footprint - training on more images does not result in more memory required
  • Resumable learning - the training step can be stopped and then resumed with inference in-between
  • Limited dependencies - only PyTorch, Torchvision and NumPy are required
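The constant memory footprint comes from accumulating the Gaussian statistics incrementally instead of storing every embedding. A minimal NumPy sketch of such an online mean/covariance update (an illustration of the idea, not this repository's actual code; the function names are hypothetical):

```python
import numpy as np

def update_stats(mean, cov_sum, n, batch):
    """Incrementally update running statistics with one batch.

    mean: (d,) running mean of the embeddings seen so far
    cov_sum: (d, d) running sum of outer products x @ x.T
    n: number of embeddings seen so far
    batch: (b, d) new embedding vectors
    """
    b = batch.shape[0]
    new_n = n + b
    new_mean = (mean * n + batch.sum(axis=0)) / new_n
    new_cov_sum = cov_sum + batch.T @ batch
    return new_mean, new_cov_sum, new_n

def finalize_cov(mean, cov_sum, n, eps=0.01):
    """Recover the unbiased covariance from the accumulators.

    eps adds a small regularization term to keep the matrix invertible.
    """
    cov = (cov_sum - n * np.outer(mean, mean)) / (n - 1)
    return cov + eps * np.eye(mean.shape[0])
```

Only the `(d,)` mean and `(d, d)` accumulator are kept between batches, so memory does not grow with the number of training images.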

Variants

This repository also contains variants on the original PaDiM model:

  • PaDiMSVDD uses a Deep-SVDD model instead of a multivariate Gaussian distribution for the normal patch representation.
  • PaDiMShared shares a single multivariate Gaussian distribution between all patches instead of learning one per patch coordinate.
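To make the PaDiMShared trade-off concrete, here is a NumPy sketch contrasting the two fitting strategies (an illustration under assumed embedding shapes, not the repository's implementation):

```python
import numpy as np

def fit_per_patch(embeddings):
    """One Gaussian per spatial coordinate, as in the original PaDiM.

    embeddings: (n_images, n_patches, d) patch embedding vectors.
    Returns means of shape (n_patches, d) and covariances of
    shape (n_patches, d, d).
    """
    n, p, d = embeddings.shape
    means = embeddings.mean(axis=0)
    covs = np.stack([np.cov(embeddings[:, i, :], rowvar=False)
                     for i in range(p)])
    return means, covs

def fit_shared(embeddings):
    """A single Gaussian over all patches, as in PaDiMShared."""
    d = embeddings.shape[-1]
    flat = embeddings.reshape(-1, d)
    return flat.mean(axis=0), np.cov(flat, rowvar=False)
```

The shared variant stores one `(d,)` mean and one `(d, d)` covariance instead of one pair per coordinate, at the cost of losing position-specific statistics.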

Installation

git clone https://github.com/Pangoraw/PaDiM.git padim

Getting started

Training

from torch.utils.data import DataLoader
from padim import PaDiM

# i) Initialize
padim = PaDiM(num_embeddings=100, device="cpu", backbone="resnet18") 

# ii) Create a dataloader producing image tensors
dataloader = DataLoader(...)

# iii) Consume the data to learn the normal distribution
# Use PaDiM.train(...)
padim.train(dataloader)

# Or PaDiM.train_one_batch(...)
for imgs in dataloader:
	padim.train_one_batch(imgs)

Testing

With the same PaDiM instance as in the Training section:

for new_imgs in test_dataloader:
	distances = padim.predict(new_imgs) 
	# distances is an (n * c) matrix of Mahalanobis distances
	# Compute metrics...
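A common way to turn these per-patch Mahalanobis distances into an image-level anomaly score and a pixel-level anomaly map is to take the maximum patch distance per image and upsample the patch grid to the input resolution. A NumPy sketch, assuming `distances` reshapes to an `(n, h, w)` patch grid and a square input of side `img_size` (the helper name and shapes are assumptions, not this repository's API):

```python
import numpy as np

def anomaly_outputs(distances, h, w, img_size):
    """Derive image scores and anomaly maps from patch distances.

    distances: (n, h*w) per-patch Mahalanobis distances (assumed layout).
    Returns (n,) image-level scores and (n, img_size, img_size) maps.
    """
    maps = distances.reshape(-1, h, w)
    # Image-level score: the most anomalous patch in each image
    scores = maps.reshape(len(maps), -1).max(axis=1)
    # Nearest-neighbor upsampling of the patch grid to input resolution
    scale = img_size // h
    full = maps.repeat(scale, axis=1).repeat(scale, axis=2)
    return scores, full
```

Gaussian smoothing of the upsampled map and min-max normalization are often applied before thresholding, as described in the paper.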

Acknowledgements

This implementation was built on the work of: