Commit 8f63ecc

update code
Signed-off-by: Kaihui-intel <[email protected]>
Kaihui-intel committed Sep 30, 2024
1 parent c2fb59c · commit 8f63ecc
Showing 1 changed file with 3 additions and 3 deletions.

docs/source/3x/transformers_like_api.md
@@ -174,7 +174,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=
 q_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="xpu", trust_remote_code=True)
 
 # optimize the model with ipex, it will improve performance.
-quantization_config = q_model.quantization_config if hasattr(user_model, "quantization_config") else None
+quantization_config = q_model.quantization_config if hasattr(q_model, "quantization_config") else None
 q_model = ipex.optimize_transformers(
     q_model, inplace=True, dtype=torch.float16, quantization_config=quantizaiton_config, device="xpu"
 )
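This hunk's fix points the `hasattr` guard at `q_model`: `user_model` is never defined in this snippet, so the old line raised a `NameError` instead of falling back to `None`. For context, a minimal sketch of the corrected step follows. The `neural_compressor.transformers` import path and the placeholder model name are assumptions for illustration; `ipex.optimize_transformers` and its arguments are taken from the hunk itself. The pre-existing `quantizaiton_config` misspelling in the unchanged context line is spelled correctly below.

```python
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoTokenizer
# Import path assumed from this document's earlier sections.
from neural_compressor.transformers import AutoModelForCausalLM

model_name_or_path = "facebook/opt-125m"  # hypothetical placeholder; substitute your model id

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
q_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path, device_map="xpu", trust_remote_code=True
)

# Guard on q_model itself; `user_model` was an undefined leftover name.
quantization_config = q_model.quantization_config if hasattr(q_model, "quantization_config") else None

# Optimize the quantized model with IPEX to improve XPU performance.
q_model = ipex.optimize_transformers(
    q_model, inplace=True, dtype=torch.float16, quantization_config=quantization_config, device="xpu"
)
```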
@@ -205,7 +205,7 @@ q_model.save_pretrained("saved_dir")
 loaded_model = AutoModelForCausalLM.from_pretrained("saved_dir", trust_remote_code=True)
 
 # Before executed the loaded model, you can call ipex.optimize_transformers function.
-quantization_config = q_model.quantization_config if hasattr(user_model, "quantization_config") else None
+quantization_config = q_model.quantization_config if hasattr(q_model, "quantization_config") else None
 loaded_model = ipex.optimize_transformers(
     loaded_model, inplace=True, dtype=torch.float16, quantization_config=quantization_config, device="xpu"
 )
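The same correction applies to the save/reload path: the quantization config is read from `q_model` before `ipex.optimize_transformers` is applied to the reloaded model. Continuing the sketch above, under the same assumptions:

```python
# Continuing from the previous sketch: persist the quantized model, reload it,
# and re-apply the IPEX optimization before running inference.
q_model.save_pretrained("saved_dir")

loaded_model = AutoModelForCausalLM.from_pretrained("saved_dir", trust_remote_code=True)

# Same fix as above: read the config from q_model, not the undefined user_model.
quantization_config = q_model.quantization_config if hasattr(q_model, "quantization_config") else None
loaded_model = ipex.optimize_transformers(
    loaded_model, inplace=True, dtype=torch.float16, quantization_config=quantization_config, device="xpu"
)
```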
@@ -217,7 +217,7 @@ prompt = "Once upon a time, a little girl"
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
 input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
 generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)
-gen_ids = q_model.generate(input_ids, **generate_kwargs)
+gen_ids = loaded_model.generate(input_ids, **generate_kwargs)
 gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
 print(gen_text)
 ```
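The final hunk routes generation through `loaded_model`, the reloaded and re-optimized model, rather than the pre-save `q_model`. Continuing the sketch under the same assumptions (note that with `do_sample=False` and `num_beams=4` this is beam search, so the `temperature` value carried over from the original snippet has no effect):

```python
# Continuing from the previous sketch: generate with the reloaded model.
prompt = "Once upon a time, a little girl"
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
# Depending on the IPEX/transformers versions, inputs may need to be moved to
# the model's device first (an assumption, not shown in the original):
# input_ids = input_ids.to("xpu")
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)
gen_ids = loaded_model.generate(input_ids, **generate_kwargs)
gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
print(gen_text)
```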
