diff --git a/kogpt/inference.py b/kogpt/inference.py index b7167b3..5084adc 100644 --- a/kogpt/inference.py +++ b/kogpt/inference.py @@ -29,7 +29,7 @@ def __init__( pretrained_model_name_or_path, revision=revision, pad_token_id=self.tokenizer.eos_token_id, torch_dtype='auto', low_cpu_mem_usage=True - ) + ).to(device) LOGGER.debug('loaded weights') LOGGER.debug('#parameters: %d', sum([p.numel() for p in self.model.parameters()]))