From 1b2ad57ab18b464daf647e73711c3589b8a6f307 Mon Sep 17 00:00:00 2001 From: statelesshz Date: Sat, 2 Nov 2024 17:36:09 +0800 Subject: [PATCH 1/2] fix load_state_dict for npu --- src/accelerate/utils/modeling.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 5f88e54e3c9..934b5b44776 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1670,9 +1670,11 @@ def load_state_dict(checkpoint_file, device_map=None): if len(set(device_map.values())) == 1: device = list(device_map.values())[0] target_device = device - if is_xpu_available(): - if isinstance(device, int): + if isinstance(device, int): + if is_xpu_available(): target_device = f"xpu:{device}" + elif is_npu_available(): + target_device = f"npu:{device}" return safe_load_file(checkpoint_file, device=target_device) From 0a923ff419023e672d7ca724567dff32d532364d Mon Sep 17 00:00:00 2001 From: statelesshz Date: Wed, 4 Dec 2024 14:44:01 +0800 Subject: [PATCH 2/2] update --- src/accelerate/utils/modeling.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 934b5b44776..73ada66544f 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1706,9 +1706,11 @@ def load_state_dict(checkpoint_file, device_map=None): progress_bar = None for device in devices: target_device = device - if is_xpu_available(): - if isinstance(device, int): + if isinstance(device, int): + if is_xpu_available(): target_device = f"xpu:{device}" + elif is_npu_available(): + target_device = f"npu:{device}" with safe_open(checkpoint_file, framework="pt", device=target_device) as f: for key in device_weights[device]: