diff --git a/neural_compressor/transformers/quantization/utils.py b/neural_compressor/transformers/quantization/utils.py
index f29613d547c..b9d171615d5 100644
--- a/neural_compressor/transformers/quantization/utils.py
+++ b/neural_compressor/transformers/quantization/utils.py
@@ -525,6 +525,11 @@ def convert_to_quantized_model(model, config, device="cpu"):
     if orig_dtype != torch.float32:
         q_model.to(dtype=orig_dtype)
 
+    if config.use_layer_wise and not (q_model.device == device or q_model.device.type == device):
+        logger.warning(
+            "Skipping device conversion to avoid running out of memory; we recommend loading the saved quantized model for inference.")
+        return q_model
+
     return q_model.to(device)
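
For context, here is a minimal, self-contained sketch of the guard's semantics. The `_should_skip_move` helper is illustrative, not part of `neural_compressor`, and it normalizes the target string via `torch.device`, a slight simplification of the diff's two-way comparison against the raw string and its `.type`:

```python
import torch


def _should_skip_move(model_device: torch.device, target: str, use_layer_wise: bool) -> bool:
    """Mirror the diff's condition: skip .to(target) when layer-wise
    quantization left the model on another device, since moving the
    whole model at once could exhaust memory."""
    on_target = model_device == torch.device(target) or model_device.type == target
    return use_layer_wise and not on_target


print(_should_skip_move(torch.device("cpu"), "cuda", True))   # True: stay put and warn
print(_should_skip_move(torch.device("cpu"), "cpu", True))    # False: already on target
print(_should_skip_move(torch.device("cpu"), "cuda", False))  # False: safe to move
```

The early return matters because layer-wise quantization keeps weights offloaded during conversion; per the warning's own recommendation, the safe path is to save the quantized model and reload it for inference rather than force a device move.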