speedup and siglip
xinyanghuang7 committed Nov 2, 2024
1 parent 9271e16 commit c41b22a
Showing 11 changed files with 57 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
model/clip_model/clip-vit-base-patch32/
model/siglip_model/siglip-vit-base-patch16/
out/*.pth
full.json
trans_json.py
9 changes: 7 additions & 2 deletions 1-pretrain_vlm.py
@@ -128,7 +128,7 @@ def init_model(lm_config):

print(f'Trainable model parameters: {count_parameters(model) / 1e6} million = {count_parameters(model) / 1e9} B (Billion)')

(vision_model, preprocess) = get_vision_model()
(vision_model, preprocess) = get_vision_model(args.visual_encoder)
vision_model = vision_model.to(args.device)
return model, tokenizer, (vision_model, preprocess)

@@ -166,10 +166,15 @@ def init_distributed_mode():
parser.add_argument("--log_interval", type=int, default=10, help="Logging interval")
parser.add_argument("--save_interval", type=int, default=100, help="Model saving interval")
parser.add_argument('--local_rank', type=int, default=-1, help='local rank for distributed training')
parser.add_argument('--visual_encoder', type=str, default="clip", help='type of visual encoder')

args = parser.parse_args()

lm_config = LMConfig()
if args.visual_encoder == "clip":
lm_config = LMConfig(image_special_token='<'*2+'>'*2, image_ids=[30]*2+[32]*2)
else:
lm_config = LMConfig(image_special_token='<'*2+'>'*2, image_ids=[30]*2+[32]*2)

max_seq_len = lm_config.max_seq_len
args.save_dir = os.path.join(args.out_dir)
os.makedirs(args.save_dir, exist_ok=True)
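For context, a minimal sketch of how the new `--visual_encoder` flag is expected to flow into the vision backbone loader; the import path is assumed from this repo's layout, and the flag values mirror the branch in `get_vision_model` below.

```python
# Hedged sketch: pick the vision encoder from the command line (import path assumed).
import argparse

from model.vision_utils import get_vision_model

parser = argparse.ArgumentParser()
parser.add_argument('--visual_encoder', type=str, default="clip",
                    help='type of visual encoder ("clip" or "siglip")')
args = parser.parse_args()

# "clip" loads clip-vit-base-patch32; any other value falls through to the SigLIP branch
vision_model, preprocess = get_vision_model(args.visual_encoder)
```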
9 changes: 7 additions & 2 deletions 2-sft_vlm.py
@@ -148,7 +148,7 @@ def init_model(lm_config):

print(f'Trainable model parameters: {count_parameters(model) / 1e6} million = {count_parameters(model) / 1e9} B (Billion)')

(vision_model, preprocess) = get_vision_model()
(vision_model, preprocess) = get_vision_model(args.visual_encoder)
vision_model = vision_model.to(args.device)
return model, tokenizer, (vision_model, preprocess)

@@ -190,10 +190,15 @@ def init_distributed_mode():
parser.add_argument('--local_rank', type=int, default=-1, help='local rank for distributed training')
parser.add_argument('--multi', type=bool, default=False, help='multi-images training')
parser.add_argument('--save_last', type=bool, default=True, help='save last step model')
parser.add_argument('--visual_encoder', type=str, default="clip", help='type of visual encoder')

args = parser.parse_args()

lm_config = LMConfig()
if args.visual_encoder == "clip":
lm_config = LMConfig(image_special_token='<'*2+'>'*2, image_ids=[30]*2+[32]*2)
else:
lm_config = LMConfig(image_special_token='<'*2+'>'*2, image_ids=[30]*2+[32]*2)

max_seq_len = lm_config.max_seq_len
args.save_dir = os.path.join(args.out_dir)
os.makedirs(args.save_dir, exist_ok=True)
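The same branching appears in the SFT script. As a sanity check, a hedged sketch showing that the `image_special_token` string and `image_ids` list passed to `LMConfig` describe the same placeholder and therefore must have equal length; the values are the ones appearing in this diff, and the import path is assumed from `model/LMConfig.py`.

```python
# Hedged sketch: keep the placeholder string and the id list the same length.
from model.LMConfig import LMConfig  # import path assumed

image_special_token = '<' * 2 + '>' * 2   # "<<>>"
image_ids = [30] * 2 + [32] * 2           # [30, 30, 32, 32]
assert len(image_special_token) == len(image_ids)

lm_config = LMConfig(image_special_token=image_special_token, image_ids=image_ids)
```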
Binary file modified model/__pycache__/LMConfig.cpython-310.pyc
Binary file modified model/__pycache__/dataset.cpython-310.pyc
Binary file modified model/__pycache__/model.cpython-310.pyc
Binary file modified model/__pycache__/vision_utils.cpython-310.pyc
2 changes: 1 addition & 1 deletion model/dataset.py
@@ -54,7 +54,7 @@ def __getitem__(self, index: int):
sample = self.data[index]
image_name = sample['image']
conversation = sample['conversations']
# minimind-v's special image placeholder: each image is split into 10 tokens, matching the count in get_img_process
# minimind-v's special image placeholder: each image is split into M tokens, matching the count in get_img_process
messages = []
# iterate over the conversation list
for i in range(0, len(conversation), 2):
33 changes: 19 additions & 14 deletions model/model.py
@@ -326,23 +326,23 @@ class Transformer(PreTrainedModel):
config_class = LMConfig
last_loss: Optional[torch.Tensor]

def __init__(self, params: LMConfig = None):
def __init__(self, params: LMConfig = None, vocab_size = 6400):
super().__init__(params)
if not params:
params = LMConfig()
self.params = params
self.vocab_size = params.vocab_size
self.vocab_size = vocab_size
self.n_layers = params.n_layers
# special image placeholder: each image is split into M tokens, matching the count in get_img_process
self.image_ids = params.image_ids

self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.tok_embeddings = nn.Embedding(self.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.layers = torch.nn.ModuleList()
for layer_id in range(self.n_layers):
self.layers.append(TransformerBlock(layer_id, params))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.output = nn.Linear(params.dim, self.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
pos_cis = precompute_pos_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
self.register_buffer("pos_cis", pos_cis, persistent=False)
@@ -372,30 +372,35 @@ def count_vision_proj(self, tokens, h, image_encoders=None, seqlen=200):
# find the indices of <image> segments in the tokens, in preparation for replacement
def find_indices(tokens, image_ids):
image_ids_tensor = torch.tensor(image_ids).to(tokens.device)
indices = []
len_image_ids = len(image_ids)

for batch_idx in range(tokens.size(0)):
for i in range(tokens.size(1) - len(image_ids) + 1):
if torch.equal(tokens[batch_idx, i:i + len(image_ids)], image_ids_tensor):
indices.append([batch_idx, i, i + len(image_ids) - 1]) # 返回batch_idx和开始结束索引
# create a view with unfold to handle sliding windows conveniently
tokens_view = tokens.unfold(1, len_image_ids, 1) # sliding windows along the second dimension
# check whether each sliding window equals image_ids_tensor
matches = (tokens_view == image_ids_tensor).all(dim=2) # compare every row within each window

# extract the matching indices
indices = {}
for batch_idx in range(tokens.size(0)):
match_indices = matches[batch_idx].nonzero(as_tuple=True)[0] # get the nonzero (matching) indices
if match_indices.numel() > 0: # if there are any matches
indices[batch_idx] = [(idx.item(), idx.item() + len_image_ids - 1) for idx in match_indices]
return indices if indices else None

image_indices = find_indices(tokens,
self.image_ids) # [0, 4, 53], [0, 54, 103], [0, 104, 153], [0, 154, 203] or [1, 4, 53], [1, 54, 103]
image_indices = find_indices(tokens, self.image_ids) # indices stored as a dict

# if image encodings are present at this point
if image_encoders is not None:
vision_proj = self.vision_proj(image_encoders)
vision_proj = vision_proj.unsqueeze(0) if len(vision_proj.shape) == 3 else vision_proj
vision_proj = vision_proj.unsqueeze(1) if len(vision_proj.shape) == 3 else vision_proj
if image_indices is not None:
# create a new tensor to hold the spliced result
new_h = []
for i in range(h.size(0)):
# i is the current batch index
img_idx = 0
for batch_idx, start_idx, end_idx in image_indices:
if batch_idx == i:
if i in image_indices: # look it up directly in the dict
for start_idx, end_idx in image_indices[i]:
# insert the vision_proj features
before = h[i][:start_idx, :]
after = h[i][end_idx + 1:, :]
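The rewritten `find_indices` above is the "speedup" half of this commit: the per-position Python loop is replaced by a single `unfold` view plus one batched comparison. A self-contained sketch with toy token ids (illustrative values only) shows the dict it returns.

```python
# Standalone sketch of the unfold-based placeholder search (toy values for illustration).
import torch

def find_indices(tokens: torch.Tensor, image_ids: list) -> dict:
    image_ids_tensor = torch.tensor(image_ids, device=tokens.device)
    len_image_ids = len(image_ids)
    # view of all sliding windows of length len_image_ids along the sequence dimension
    windows = tokens.unfold(1, len_image_ids, 1)        # (batch, n_windows, len_image_ids)
    matches = (windows == image_ids_tensor).all(dim=2)  # (batch, n_windows)
    indices = {}
    for batch_idx in range(tokens.size(0)):
        starts = matches[batch_idx].nonzero(as_tuple=True)[0]
        if starts.numel() > 0:
            indices[batch_idx] = [(s.item(), s.item() + len_image_ids - 1) for s in starts]
    return indices if indices else None

tokens = torch.tensor([[1, 30, 30, 32, 32, 5],
                       [30, 30, 32, 32, 7, 8]])
print(find_indices(tokens, [30, 30, 32, 32]))
# {0: [(1, 4)], 1: [(0, 3)]}
```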
5 changes: 5 additions & 0 deletions model/siglip_model/README.md
@@ -0,0 +1,5 @@
* The siglip-base-patch16-224 model needs to be downloaded into this directory

```bash
git clone https://hf-mirror.com/google/siglip-base-patch16-224
```
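A hedged loading check for the downloaded weights, assuming the clone ends up under the directory name referenced elsewhere in this commit (`./model/siglip_model/siglip-vit-base-patch16`); adjust the path if the clone produces a different folder name.

```python
# Hedged sanity check that the local SigLIP weights load (path assumed from model/vision_utils.py).
from transformers import SiglipModel, SiglipProcessor

model_path = "./model/siglip_model/siglip-vit-base-patch16"
model = SiglipModel.from_pretrained(model_path)
processor = SiglipProcessor.from_pretrained(model_path)
print(type(model).__name__, type(processor).__name__)
```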
26 changes: 17 additions & 9 deletions model/vision_utils.py
@@ -1,5 +1,5 @@
import warnings
from transformers import CLIPProcessor, CLIPModel
from transformers import CLIPProcessor, CLIPModel, SiglipProcessor, SiglipModel
from PIL import Image
import requests
import torch
@@ -8,19 +8,27 @@
warnings.filterwarnings('ignore')


def get_vision_model():
def get_vision_model(encoder_type):
# load the pretrained CLIP or SigLIP model and processor
model_path = "./model/clip_model/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_path)
processor = CLIPProcessor.from_pretrained(model_path)
if encoder_type == "clip":
model_path = "./model/clip_model/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_path)
processor = CLIPProcessor.from_pretrained(model_path)
else:
model_path = "./model/siglip_model/siglip-vit-base-patch16"
model = SiglipModel.from_pretrained(model_path)
processor = SiglipProcessor.from_pretrained(model_path)
return (model, processor)


def get_img_process(image, processor):
# resize the image to 144*144
# resize the image to 224*224
image = image.resize((224, 224))
if image.mode in ['RGBA', 'LA']: # handle images with an alpha (transparency) channel
image = image.convert('RGB')
# process each patch with the CLIPProcessor
inputs = processor(images=image, return_tensors="pt", clean_up_tokenization_spaces=False)
# inputs = processor(images=image, return_tensors="pt", clean_up_tokenization_spaces=False)
inputs = processor(images=image, return_tensors="pt")
return inputs


@@ -32,7 +40,7 @@ def hook_fn(module, input, output):
embeddings.append(output.last_hidden_state)

# extract the image tensor from the BatchEncoding
if isinstance(batch_encoding, transformers.tokenization_utils_base.BatchEncoding):
if isinstance(batch_encoding, transformers.tokenization_utils_base.BatchEncoding) or isinstance(batch_encoding, transformers.feature_extraction_utils.BatchFeature):
image_tensor = batch_encoding['pixel_values']
else:
image_tensor = batch_encoding # torch.Size([32, 4, 3, 224, 224])
@@ -58,5 +66,5 @@ def hook_fn(module, input, output):
hook.remove()

# stack all feature vectors into a single tensor
all_embeddings = torch.stack(embeddings, dim=0).squeeze() # torch.Size([32, 4, 50, 768])
all_embeddings = torch.stack(embeddings, dim=0).squeeze() # torch.Size([32, 4, 50, 768]) or torch.Size([32, 2, 196, 768])
return all_embeddings
