
enable llava on torchchat #1183

Open · wants to merge 9 commits into main

Conversation

Gasoonjia
Contributor

This PR enables LLaVA-1.5 on torchchat, the first multi-modality model supported by torchchat.

How to play?

Use the --prompt flag for text input and --image-prompt for image input.
e.g.

(torchchat) [ ~/torchchat (9e4350d7b)]$ python torchchat.py generate llava-1.5 --prompt "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image> What are the things I should be cautious about when I visit here? ASSISTANT:" --image-prompt ../view.jpg
Using device=cuda NVIDIA PG509-210
Loading model...
Time to load model: 5.16 seconds
-----------------------------------------------------------
Image prompts ['../view.jpg']
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image> What are the things I should be cautious about when I visit here? ASSISTANT: When visiting this vibrant and fascinating place, one should be cautious about the potential for the area to be crowded or filled with tourists. This might lead to overcrowding and a loss of personal space. Additionally, the vibrant design and colorful patterns might be visually stimulating, so it's essential to be cautious while taking photographs, ensuring not to bump into other people or unintentionally obstruct their views. It's also important to be aware of your surroundings and belongings, as the brightness of the colors and intricate patterns can make them harder to spot. Lastly, it's always a good idea to respect the cultural and artistic value of such places, by refraining from touching or interacting with the artwork without permission or being mindful of your surroundings.
2024-09-23:20:24:10,645 INFO     [generate.py:1031]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~                
Generated 185 tokens                 
Time for inference 1: 11.4913 sec total                 
Time to first token: 0.6280 sec with parallel prefill.                

      Total throughput: 16.1861 tokens/sec, 0.0618 s/token                 
First token throughput: 1.5923 tokens/sec, 0.6280 s/token                 
 Next token throughput: 17.0298 tokens/sec, 0.0587 s/token                     
2024-09-23:20:24:10,645 INFO     [generate.py:1042] 
Bandwidth achieved: 228.66 GB/s
2024-09-23:20:24:10,645 INFO     [generate.py:1046] *** This first iteration will include cold start effects for dynamic import, hardware caches. ***

========================================

It can also handle text-only input without an image:

(torchchat) [ ~/torchchat (9e4350d7b)]$ python torchchat.py generate llava-1.5 --prompt "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER:  What are the things I should be cautious about when I visit Canada? ASSISTANT:"
Using device=cuda NVIDIA PG509-210
Loading model...
Time to load model: 5.50 seconds
-----------------------------------------------------------
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER:  What are the things I should be cautious about when I visit Canada? ASSISTANT: There are several things you should be cautious about when visiting Canada, including:

1. Health and safety: Canada has a generally safe environment, but as with any country, you should be mindful of your surroundings and take precautions to stay safe. This includes being cautious of pickpockets in crowded areas, watching out for traffic when crossing the street, and avoiding potential hazards in public spaces.
2. Weather: Canada has a varied climate, with different regions experiencing different weather conditions. Be prepared for unexpected changes in weather and dress appropriately for the climate you will be visiting.
3. Customs regulations: When bringing items into Canada, you must declare any goods that are subject to customs duty or tax. There are also restrictions on bringing certain items into the country, such as food and plants.
4. Language: While English and French are the official languages of Canada, not all
2024-09-23:20:25:49,785 INFO     [generate.py:1031]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~                
Generated 199 tokens                 
Time for inference 1: 13.8391 sec total                 
Time to first token: 0.5637 sec with parallel prefill.                

      Total throughput: 14.4518 tokens/sec, 0.0692 s/token                 
First token throughput: 1.7741 tokens/sec, 0.5637 s/token                 
 Next token throughput: 14.9901 tokens/sec, 0.0667 s/token                     
2024-09-23:20:25:49,786 INFO     [generate.py:1042] 
Bandwidth achieved: 204.16 GB/s
2024-09-23:20:25:49,786 INFO     [generate.py:1046] *** This first iteration will include cold start effects for dynamic import, hardware caches. ***

========================================


pytorch-bot bot commented Sep 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1183

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 937e7ed with merge base 2cf4016 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Sep 24, 2024
@Gasoonjia Gasoonjia changed the title onboarding llava on torchchat enable llava on torchchat Sep 24, 2024
Contributor

@Jack-Khuu Jack-Khuu left a comment

Reminder to test that Flamingo and 3.1 still work as expected

Also reminder that to test convert_hf_checkpoint you need to delete your download/conversion and rerun

@@ -21,9 +24,176 @@

from torchchat.model import ModelArgs

def remap_llava_checkpoint(llava_ckpt):
Contributor

Was this written inhouse?

Contributor Author

I'm not quite following your question.
This function is consumed by convert_llava_checkpoint to produce the remapped checkpoint.
I made it a separate function to simplify the logic.
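For readers following along, a checkpoint key-remapping helper of this kind usually looks something like the sketch below. The prefix table here is purely illustrative; the PR's actual mapping between HF LLaVA names and torchchat names differs.

```python
def remap_llava_checkpoint(llava_ckpt: dict) -> dict:
    """Rename HF-style LLaVA checkpoint keys to torchchat-style keys.

    NOTE: the prefix table below is a hypothetical example,
    not the mapping used in this PR.
    """
    prefix_map = {
        "language_model.model.": "decoder.",
        "vision_tower.": "encoder.",
        "multi_modal_projector.": "projector.",
    }
    remapped = {}
    for key, value in llava_ckpt.items():
        new_key = key
        for old, new in prefix_map.items():
            if key.startswith(old):
                # Replace the matched prefix, keep the rest of the key.
                new_key = new + key[len(old):]
                break
        remapped[new_key] = value
    return remapped
```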

@@ -21,9 +24,176 @@

from torchchat.model import ModelArgs

Contributor

Suggested change
"""
Llava Conversion Code
"""

Contributor

Code comment blocks to help us move things around later


tokenizer_path = model_dir / "tokenizer.model"
shutil.copy(tokenizer_files[0], tokenizer_path)


Contributor

Suggested change
"""
Text-Only Conversion Code
"""

if batch and self.model.config.model_type == ModelType.Llava:
    context_len, next_token = next_token
else:
    context_len, next_token = T, next_token
Contributor

Suggested change
context_len, next_token = T, next_token
context_len = T

    encoded = batch["tokens"]
elif self.model.config.model_type == ModelType.Llava:
    # TODO: double check the tokenizer.
    def find_subtensor(tensor, target):
Contributor

typehints
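A type-hinted sketch of such a subsequence search over 1-D tensors, assuming the helper returns the start index of the first match (the PR's version may use a different return convention):

```python
import torch


def find_subtensor(tensor: torch.Tensor, target: torch.Tensor) -> int:
    """Return the start index of the first occurrence of 1-D `target`
    inside 1-D `tensor`, or -1 if it does not occur.

    Illustrative sketch; not necessarily the PR's exact helper.
    """
    n, m = tensor.size(0), target.size(0)
    if m == 0 or m > n:
        return -1
    for i in range(n - m + 1):
        # Compare the m-length window starting at i against target.
        if torch.equal(tensor[i : i + m], target):
            return i
    return -1
```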

Comment on lines +969 to +984
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
Contributor

Outdated comment?
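For context, the rotary application the docstring describes can be sketched as follows, matching the broadcasting behavior it explains (this mirrors the common Hugging Face formulation; it is not necessarily the PR's exact code):

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotate half the hidden dims: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def hf_apply_rotary_emb(q, k, cos, sin, unsqueeze_dim: int = 1):
    # cos/sin have shape [batch, seq_len, head_dim]; unsqueeze them so
    # they broadcast against q/k of shape [batch, heads, seq_len, head_dim].
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

With cos = 1 and sin = 0 the transform is the identity, which is a quick sanity check on the broadcasting.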

@@ -0,0 +1,80 @@
import torch
import torchvision as tv
Contributor

lintrunner ordering

padding with median RGB value to make a square, scaling, and normalizing.

Args:
img_address (str): Address of the local image file will be forwarded to the model.
Contributor

Autogen'd comment?
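The step the docstring describes, padding with the median RGB value to form a square, can be sketched as below. This is a minimal illustration assuming a CHW float tensor; the PR's implementation uses torchvision and may differ in details.

```python
import torch


def pad_to_square_with_median(img: torch.Tensor) -> torch.Tensor:
    """Pad a CHW image tensor to a square using the per-channel median.

    Illustrative sketch of the preprocessing described in the docstring;
    not necessarily the PR's exact implementation.
    """
    c, h, w = img.shape
    side = max(h, w)
    # Per-channel median used as the padding color.
    median = img.reshape(c, -1).median(dim=1).values
    out = median.view(c, 1, 1).expand(c, side, side).clone()
    # Center the original image inside the square canvas.
    top = (side - h) // 2
    left = (side - w) // 2
    out[:, top : top + h, left : left + w] = img
    return out
```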

@@ -919,6 +937,58 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
return x_out2.type_as(x)
Contributor

nit: Can move apply_rotary_emb so that it is sequentially after hf_apply_rotary_emb?

Mainly for keeping concepts together

Contributor Author

I'd like to keep the current structure, with all HF rotary embedding functions grouped together and all previous embedding functions in a separate section.

encoded = batch["tokens"]
assert len(images) == 1, "Only one image prompt is supported for now"

# TODO: update encoded variable for multi-modality models to include image tokens.
Contributor

Can you explain this to me?

Labels
CLA Signed This label is managed by the Meta Open Source bot.

3 participants