Adapted Vicuna and ChatGLM models for the Ascend 310P NPU
zer0py2c committed Sep 22, 2024
1 parent 3773213 commit 1141f19
Showing 2 changed files with 48 additions and 6 deletions.
32 changes: 26 additions & 6 deletions fastchat/model/model_adapter.py
@@ -266,11 +266,30 @@ def load_model(
            )
    elif device == "npu":
        kwargs = {"torch_dtype": torch.float16}
        # Try to load torch_npu; while it looks unused, importing it links Ascend NPU support into torch.
        # This port targets torch-npu 2.1.0.post6 for Ascend 310P chips.
        from pkg_resources import parse_version, get_distribution

        try:
            import torch_npu

            required_version = "2.1.0.post6"
            installed_version = get_distribution("torch_npu").version
            if parse_version(installed_version) != parse_version(required_version):
                warnings.warn(f"torch-npu {installed_version} is installed, but {required_version} is expected for Ascend 310P.")
        except ImportError:
            warnings.warn("Ascend Extension for PyTorch is not installed.")

        if num_gpus != 1:
            num_npus = num_gpus
            kwargs["device_map"] = "balanced"
            if max_gpu_memory is None:
                kwargs["device_map"] = "sequential"
                available_gpu_memory = get_npu_memory(num_npus)
                kwargs["max_memory"] = {
                    i: str(int(available_gpu_memory[i] * 0.8)) + "GiB"
                    for i in range(num_npus)
                }
            else:
                kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_npus)}
    else:
        raise ValueError(f"Invalid device: {device}")
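A minimal sketch of how the multi-NPU branch above is exercised, assuming the load_model signature defined in this file; the checkpoint name and NPU count are placeholders and not part of the commit:

from fastchat.model.model_adapter import load_model

# num_gpus > 1 triggers the device_map / max_memory logic added above;
# with max_gpu_memory=None the per-NPU budget comes from get_npu_memory().
model, tokenizer = load_model(
    "lmsys/vicuna-7b-v1.5",  # placeholder checkpoint
    device="npu",
    num_gpus=2,              # placeholder NPU count
    max_gpu_memory=None,
)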

@@ -373,11 +392,8 @@ def load_model(
    ):
        model = ipex.optimize(model, dtype=kwargs["torch_dtype"])

    if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device in (
        "mps",
        "xpu",
        "npu",
    ):
    if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or \
            (device == "npu" and num_gpus == 1) or device in ("mps", "xpu"):
        model.to(device)

    if device == "xpu":
@@ -702,6 +718,8 @@ def match(self, model_path: str):
return "vicuna" in model_path.lower()

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
# Disable JIT just-in-time compilation
torch.npu.set_compile_mode(jit_compile=False)
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=self.use_fast_tokenizer, revision=revision
Expand Down Expand Up @@ -843,6 +861,8 @@ def match(self, model_path: str):
return "chatglm" in model_path.lower()

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
# Disable JIT on-the-fly compilation
torch.npu.set_compile_mode(jit_compile=False)
revision = from_pretrained_kwargs.get("revision", "main")
if "chatglm3" in model_path.lower():
tokenizer = AutoTokenizer.from_pretrained(
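Both adapters call torch.npu.set_compile_mode(jit_compile=False) unconditionally, which assumes torch_npu has already been imported (the device == "npu" branch above pulls it in; torch.npu only exists once torch_npu patches it onto torch). A minimal defensive sketch, offered as an assumption rather than part of this commit, guards the call so CPU/GPU runs are unaffected:

import torch  # torch_npu, once imported, exposes the torch.npu namespace

# Guarded variant of the call added in the adapters above (sketch only).
if hasattr(torch, "npu") and torch.npu.is_available():
    torch.npu.set_compile_mode(jit_compile=False)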
22 changes: 22 additions & 0 deletions fastchat/utils.py
@@ -149,6 +149,28 @@ def get_gpu_memory(max_gpus=None):
    return gpu_memory


def get_npu_memory(max_npus=None):
    """Get available memory for each NPU in GiB."""
    import torch_npu

    npu_memory = []
    num_npus = (
        torch_npu.npu.device_count()
        if max_npus is None
        else min(max_npus, torch_npu.npu.device_count())
    )

    for npu_id in range(num_npus):
        with torch_npu.npu.device(npu_id):
            device = torch_npu.npu.current_device()
            npu_properties = torch_npu.npu.get_device_properties(device)
            total_memory = npu_properties.total_memory / (1024**3)
            allocated_memory = torch_npu.npu.memory_allocated() / (1024**3)
            available_memory = total_memory - allocated_memory
            npu_memory.append(available_memory)
    return npu_memory
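get_npu_memory() mirrors get_gpu_memory() above it; load_model() turns its output into a per-device budget with roughly 20% headroom. A small usage sketch of that conversion, assuming two visible NPUs and illustrative free-memory values:

# Keep ~80% of the free memory on each NPU as the max_memory budget
# that load_model() forwards to the model's from_pretrained kwargs.
available = get_npu_memory(max_npus=2)  # e.g. [21.3, 21.3] (GiB, illustrative)
max_memory = {i: str(int(mem * 0.8)) + "GiB" for i, mem in enumerate(available)}
# -> {0: "17GiB", 1: "17GiB"}, used together with device_map="sequential"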


def oai_moderation(text, custom_thresholds=None):
"""
Check whether the text violates OpenAI moderation API.
