diff --git a/denoising_diffusion_pytorch/attend.py b/denoising_diffusion_pytorch/attend.py index 333b75113..30e0f0edf 100644 --- a/denoising_diffusion_pytorch/attend.py +++ b/denoising_diffusion_pytorch/attend.py @@ -60,7 +60,9 @@ def __init__( device_properties = torch.cuda.get_device_properties(torch.device('cuda')) - if device_properties.major == 8 and device_properties.minor == 0: + device_version = version.parse(f'{device_properties.major}.{device_properties.minor}') + + if device_version > version.parse('8.0'): print_once('A100 GPU detected, using flash attention if input tensor is on cuda') self.cuda_config = AttentionConfig(True, False, False) else: diff --git a/denoising_diffusion_pytorch/version.py b/denoising_diffusion_pytorch/version.py index 2940e26c2..fca23f681 100644 --- a/denoising_diffusion_pytorch/version.py +++ b/denoising_diffusion_pytorch/version.py @@ -1 +1 @@ -__version__ = '2.0.8' +__version__ = '2.0.10'