diff --git a/src/llm_vm/onsite_llm.py b/src/llm_vm/onsite_llm.py
index f0154b1c..a70f5a08 100644
--- a/src/llm_vm/onsite_llm.py
+++ b/src/llm_vm/onsite_llm.py
@@ -27,6 +27,7 @@
 from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
 from trl import SFTTrainer
 from sentence_transformers import SentenceTransformer
+from vllm import LLM, SamplingParams
@@ -86,7 +87,7 @@ def __getitem__(self, idx):
         return self.dataset[idx]

 class BaseOnsiteLLM(ABC):
-    def __init__(self,model_uri=None, tokenizer_kw_args={}, model_kw_args={}):
+    def __init__(self, model_uri=None, vllm_support=None, tokenizer_kw_args={}, model_kw_args={}):
         if model_uri != None :
             self.model_uri= model_uri
         if model_uri is None and self.model_uri is None:
@@ -94,6 +95,8 @@ def __init__(self,model_uri=None, tokenizer_kw_args={}, model_kw_args={}):
         self.model_name : str = self.model_uri.split('/')[-1] # our default for deriving model name
         self.model=self.model_loader(**model_kw_args)
         self.tokenizer=self.tokenizer_loader(**tokenizer_kw_args)
+        # None means "inherit the class-level flag"; subclasses set vllm_support = False to opt out.
+        self.vllm_support = vllm_support if vllm_support is not None else getattr(self, 'vllm_support', True)

         # Move the model to the specified device(s)
         if isinstance(device, list):
@@ -145,6 +148,14 @@ def generate(self,prompt,max_length=100, tokenizer_kwargs={}, generation_kwargs=
            I think it takes about a week for the apple to grow.
         """
+        n = generation_kwargs.get('num_return_sequences', 1)
+        if n > 1 and self.vllm_support:
+            print("doing parallel sampling using vllm")
+            # NOTE: this builds a fresh vLLM engine, reloading the weights, on every call.
+            sampling_params = SamplingParams(n=n, max_tokens=max_length)
+            llm = LLM(model=self.model_uri)
+            outputs = llm.generate(prompt, sampling_params)
+            return [completion.text for completion in outputs[0].outputs]
         if isinstance(device, list):
             # If multiple GPUs are available, use first one
@@ -442,6 +453,7 @@ class SmallLocalNeo(BaseOnsiteLLM):
        generate: Generates a response from a given prompt with the loaded LLM and tokenizer
     """
     model_uri="EleutherAI/gpt-neo-1.3B"
+    vllm_support = False  # this architecture opts out of the vLLM fast path
     def model_loader(self):
         return GPTNeoForCausalLM.from_pretrained(self.model_uri)
@@ -680,6 +692,7 @@ class SmallLocalFlanT5(BaseOnsiteLLM):
     """
     model_uri="google/flan-t5-small"
+    vllm_support = False
     def model_loader(self):
         return AutoModelForSeq2SeqLM.from_pretrained(self.model_uri)
     def tokenizer_loader(self):
@@ -704,6 +717,7 @@ class SmallLocalBERT(BaseOnsiteLLM):
     """
     model_uri = "bert-base-cased"
+    vllm_support = False
     def model_loader(self):
         return AutoModelForMaskedLM.from_pretrained(self.model_uri)
     def tokenizer_loader(self):
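
A minimal sketch of how the new path would be driven, for reviewers. The class and checkpoint names come from this diff, but `SmallLocalNeo` ships with `vllm_support = False`, so the example overrides the flag explicitly; whether vLLM can actually serve that checkpoint is an assumption of the sketch, not something the diff guarantees.

```python
from llm_vm.onsite_llm import SmallLocalNeo

# Explicit override of the class-level vllm_support = False, purely for
# illustration; assumes vLLM can load "EleutherAI/gpt-neo-1.3B".
model = SmallLocalNeo(vllm_support=True)

# num_return_sequences > 1 routes generation through the vLLM branch above.
completions = model.generate(
    "How long does it take for an apple to grow?",
    max_length=50,
    generation_kwargs={'num_return_sequences': 4},
)
print(len(completions))  # 4 sampled continuations of the same prompt
```

Note that the branch constructs a new `LLM(...)` engine inside every `generate` call; caching the engine on the instance would avoid reloading the checkpoint on each request.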