From c41b22a508a1ba068181b3970fd351b85406d6a0 Mon Sep 17 00:00:00 2001
From: xinyanghuang7
Date: Sat, 2 Nov 2024 03:21:22 -0400
Subject: [PATCH] speedup and siglip

---
 .gitignore                                     |  1 +
 1-pretrain_vlm.py                              |  9 +++--
 2-sft_vlm.py                                   |  9 +++--
 model/__pycache__/LMConfig.cpython-310.pyc     | Bin 1601 -> 1601 bytes
 model/__pycache__/dataset.cpython-310.pyc      | Bin 6234 -> 6234 bytes
 model/__pycache__/model.cpython-310.pyc        | Bin 16153 -> 16442 bytes
 .../__pycache__/vision_utils.cpython-310.pyc   | Bin 1672 -> 1923 bytes
 model/dataset.py                               |  2 +-
 model/model.py                                 | 33 ++++++++++--------
 model/siglip_model/README.md                   |  5 +++
 model/vision_utils.py                          | 26 +++++++++-----
 11 files changed, 57 insertions(+), 28 deletions(-)
 create mode 100644 model/siglip_model/README.md

diff --git a/.gitignore b/.gitignore
index 5f10ded..f0b6441 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 model/clip_model/clip-vit-base-patch32/
+model/siglip_model/siglip-vit-base-patch16/
 out/*.pth
 full.json
 trans_json.py
diff --git a/1-pretrain_vlm.py b/1-pretrain_vlm.py
index 8cec330..5a8270d 100644
--- a/1-pretrain_vlm.py
+++ b/1-pretrain_vlm.py
@@ -128,7 +128,7 @@ def init_model(lm_config):
     print(f'Trainable model parameters: {count_parameters(model) / 1e6} million = {count_parameters(model) / 1e9} B (Billion)')

-    (vision_model, preprocess) = get_vision_model()
+    (vision_model, preprocess) = get_vision_model(args.visual_encoder)
     vision_model = vision_model.to(args.device)
     return model, tokenizer, (vision_model, preprocess)

@@ -166,10 +166,15 @@ def init_distributed_mode():
     parser.add_argument("--log_interval", type=int, default=10, help="Logging interval")
     parser.add_argument("--save_interval", type=int, default=100, help="Model saving interval")
     parser.add_argument('--local_rank', type=int, default=-1, help='local rank for distributed training')
+    parser.add_argument('--visual_encoder', type=str, default="clip", help='type of visual encoder')

     args = parser.parse_args()

-    lm_config = LMConfig()
+    if args.visual_encoder == "clip":
+        lm_config = LMConfig(image_special_token='<'*25+'>'*25, image_ids=[30]*25+[32]*25)  # 50 placeholder ids, matching CLIP's 50 feature tokens per image
+    else:
+        lm_config = LMConfig(image_special_token='<'*98+'>'*98, image_ids=[30]*98+[32]*98)  # 196 placeholder ids, matching SigLIP's 196 feature tokens per image
+
     max_seq_len = lm_config.max_seq_len
     args.save_dir = os.path.join(args.out_dir)
     os.makedirs(args.save_dir, exist_ok=True)
diff --git a/2-sft_vlm.py b/2-sft_vlm.py
index 768a6e1..15c5d04 100644
--- a/2-sft_vlm.py
+++ b/2-sft_vlm.py
@@ -148,7 +148,7 @@ def init_model(lm_config):
     print(f'Trainable model parameters: {count_parameters(model) / 1e6} million = {count_parameters(model) / 1e9} B (Billion)')

-    (vision_model, preprocess) = get_vision_model()
+    (vision_model, preprocess) = get_vision_model(args.visual_encoder)
     vision_model = vision_model.to(args.device)
     return model, tokenizer, (vision_model, preprocess)

@@ -190,10 +190,15 @@ def init_distributed_mode():
     parser.add_argument('--local_rank', type=int, default=-1, help='local rank for distributed training')
     parser.add_argument('--multi', type=bool, default=False, help='multi-images training')
     parser.add_argument('--save_last', type=bool, default=True, help='save last step model')
+    parser.add_argument('--visual_encoder', type=str, default="clip", help='type of visual encoder')

     args = parser.parse_args()

-    lm_config = LMConfig()
+    if args.visual_encoder == "clip":
+        lm_config = LMConfig(image_special_token='<'*25+'>'*25, image_ids=[30]*25+[32]*25)  # 50 placeholder ids, matching CLIP's 50 feature tokens per image
+    else:
+        lm_config = LMConfig(image_special_token='<'*98+'>'*98, image_ids=[30]*98+[32]*98)  # 196 placeholder ids, matching SigLIP's 196 feature tokens per image
+
     max_seq_len = lm_config.max_seq_len
     args.save_dir = os.path.join(args.out_dir)
     os.makedirs(args.save_dir, exist_ok=True)
diff --git a/model/__pycache__/LMConfig.cpython-310.pyc b/model/__pycache__/LMConfig.cpython-310.pyc
index 3b4c852d1562686d6a1ace7212d27aaee07cb307..fdc47c4d11b047d042095a122eaa8429a037187e 100644
Binary files a/model/__pycache__/LMConfig.cpython-310.pyc and b/model/__pycache__/LMConfig.cpython-310.pyc differ
diff --git a/model/__pycache__/dataset.cpython-310.pyc b/model/__pycache__/dataset.cpython-310.pyc
index ef0806eb211cfd5d2d403703d0d342e0e006898e..93d01323da016bb05678cbc6a78a1c7b19cac0bc 100644
Binary files a/model/__pycache__/dataset.cpython-310.pyc and b/model/__pycache__/dataset.cpython-310.pyc differ
diff --git a/model/__pycache__/model.cpython-310.pyc b/model/__pycache__/model.cpython-310.pyc
Binary files a/model/__pycache__/model.cpython-310.pyc and b/model/__pycache__/model.cpython-310.pyc differ
diff --git a/model/__pycache__/vision_utils.cpython-310.pyc b/model/__pycache__/vision_utils.cpython-310.pyc
index 97b6ac52cae67722963ef6b29cb5ebb531934d40..f995df75d9b95b28ba4941dc7405dbaefc9fc15b 100644
Binary files a/model/__pycache__/vision_utils.cpython-310.pyc and b/model/__pycache__/vision_utils.cpython-310.pyc differ
diff --git a/model/dataset.py b/model/dataset.py
index f5b7511..bd6ba9f 100644
--- a/model/dataset.py
+++ b/model/dataset.py
@@ -54,7 +54,7 @@ def __getitem__(self, index: int):
         sample = self.data[index]
         image_name = sample['image']
         conversation = sample['conversations']
-        # minimind-v's special image placeholder: each image is split into 10 tokens, matching the count in get_img_process
+        # minimind-v's special image placeholder: each image is split into M tokens, matching the count in get_img_process
         messages = []
         # Iterate over the conversation list
         for i in range(0, len(conversation), 2):
diff --git a/model/model.py b/model/model.py
index 79fd7d2..6f8fdbb 100644
--- a/model/model.py
+++ b/model/model.py
@@ -326,23 +326,23 @@ class Transformer(PreTrainedModel):
     config_class = LMConfig
     last_loss: Optional[torch.Tensor]

-    def __init__(self, params: LMConfig = None):
+    def __init__(self, params: LMConfig = None, vocab_size=6400):
         super().__init__(params)
         if not params:
             params = LMConfig()
         self.params = params
-        self.vocab_size = params.vocab_size
+        self.vocab_size = vocab_size
         self.n_layers = params.n_layers
         # Special image placeholder ids: each image is split into M tokens, matching the count in get_img_process
         self.image_ids = params.image_ids

-        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.tok_embeddings = nn.Embedding(self.vocab_size, params.dim)
         self.dropout = nn.Dropout(params.dropout)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(self.n_layers):
             self.layers.append(TransformerBlock(layer_id, params))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.output = nn.Linear(params.dim, self.vocab_size, bias=False)
         self.tok_embeddings.weight = self.output.weight
         pos_cis = precompute_pos_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
         self.register_buffer("pos_cis", pos_cis, persistent=False)
@@ -372,30 +372,35 @@ def count_vision_proj(self, tokens, h, image_encoders=None, seqlen=200):
         # Find the indices of the image-id segments in the tokens, in preparation for replacement
        def find_indices(tokens, image_ids):
             image_ids_tensor = torch.tensor(image_ids).to(tokens.device)
-            indices = []
+            len_image_ids = len(image_ids)
-            for batch_idx in range(tokens.size(0)):
-                for i in range(tokens.size(1) - len(image_ids) + 1):
-                    if torch.equal(tokens[batch_idx, i:i + len(image_ids)], image_ids_tensor):
-                        indices.append([batch_idx, i, i + len(image_ids) - 1])  # record batch_idx plus the start/end indices
+            # Create a sliding-window view with unfold, which replaces the per-position Python loop
+            tokens_view = tokens.unfold(1, len_image_ids, 1)  # sliding windows along the sequence dimension
+            # Check whether each sliding window equals image_ids_tensor
+            matches = (tokens_view == image_ids_tensor).all(dim=2)  # compare every position within each window
+            # Collect the matching indices
+            indices = {}
+            for batch_idx in range(tokens.size(0)):
+                match_indices = matches[batch_idx].nonzero(as_tuple=True)[0]  # indices of the non-zero (matching) windows
+                if match_indices.numel() > 0:  # if there are matches
+                    indices[batch_idx] = [(idx.item(), idx.item() + len_image_ids - 1) for idx in match_indices]
             return indices if indices else None

-        image_indices = find_indices(tokens,
-                                      self.image_ids)  # [0, 4, 53], [0, 54, 103], [0, 104, 153], [0, 154, 203] or [1, 4, 53], [1, 54, 103]
+        image_indices = find_indices(tokens, self.image_ids)  # indices stored as a dict keyed by batch index

         # If image encodings are present at this point
         if image_encoders is not None:
             vision_proj = self.vision_proj(image_encoders)
-            vision_proj = vision_proj.unsqueeze(0) if len(vision_proj.shape) == 3 else vision_proj
+            vision_proj = vision_proj.unsqueeze(1) if len(vision_proj.shape) == 3 else vision_proj
             if image_indices is not None:
                 # Build a new tensor to hold the spliced result
                 new_h = []
                 for i in range(h.size(0)):  # i is the current batch index
                     img_idx = 0
-                    for batch_idx, start_idx, end_idx in image_indices:
-                        if batch_idx == i:
+                    if i in image_indices:  # look the spans up directly in the dict
+                        for start_idx, end_idx in image_indices[i]:
                             # Splice in the vision_proj features
                             before = h[i][:start_idx, :]
                             after = h[i][end_idx + 1:, :]
diff --git a/model/siglip_model/README.md b/model/siglip_model/README.md
new file mode 100644
index 0000000..e9e0702
--- /dev/null
+++ b/model/siglip_model/README.md
@@ -0,0 +1,5 @@
+* Download the siglip-base-patch16-224 model into this directory:
+
+```bash
+git clone https://hf-mirror.com/google/siglip-base-patch16-224
+```
\ No newline at end of file
diff --git a/model/vision_utils.py b/model/vision_utils.py
index 56d70fe..61e14ea 100644
--- a/model/vision_utils.py
+++ b/model/vision_utils.py
@@ -1,5 +1,5 @@
 import warnings
-from transformers import CLIPProcessor, CLIPModel
+from transformers import CLIPProcessor, CLIPModel, SiglipProcessor, SiglipModel
 from PIL import Image
 import requests
 import torch
@@ -8,19 +8,27 @@
 warnings.filterwarnings('ignore')


-def get_vision_model():
+def get_vision_model(encoder_type):
     # Load the pretrained vision model and its processor
-    model_path = "./model/clip_model/clip-vit-base-patch32"
-    model = CLIPModel.from_pretrained(model_path)
-    processor = CLIPProcessor.from_pretrained(model_path)
+    if encoder_type == "clip":
+        model_path = "./model/clip_model/clip-vit-base-patch32"
+        model = CLIPModel.from_pretrained(model_path)
+        processor = CLIPProcessor.from_pretrained(model_path)
+    else:
+        model_path = "./model/siglip_model/siglip-vit-base-patch16"
+        model = SiglipModel.from_pretrained(model_path)
+        processor = SiglipProcessor.from_pretrained(model_path)
     return (model, processor)


 def get_img_process(image, processor):
-    # Resize the image to 144*144
+    # Resize the image to 224*224
     image = image.resize((224, 224))
+    if image.mode in ['RGBA', 'LA']:  # handle images with an alpha channel
+        image = image.convert('RGB')
     # Run the image through the processor
-    inputs = processor(images=image, return_tensors="pt", clean_up_tokenization_spaces=False)
+    # inputs = processor(images=image, return_tensors="pt", clean_up_tokenization_spaces=False)
+    inputs = processor(images=image, return_tensors="pt")
     return inputs

@@ -32,7 +40,7 @@ def hook_fn(module, input, output):
         embeddings.append(output.last_hidden_state)

     # Extract the image tensor from the BatchEncoding
-    if isinstance(batch_encoding, transformers.tokenization_utils_base.BatchEncoding):
+    if isinstance(batch_encoding, (transformers.tokenization_utils_base.BatchEncoding, transformers.feature_extraction_utils.BatchFeature)):
         image_tensor = batch_encoding['pixel_values']
     else:
         image_tensor = batch_encoding  # torch.Size([32, 4, 3, 224, 224])
@@ -58,5 +66,5 @@ def hook_fn(module, input, output):
     hook.remove()

     # Stack all feature vectors into a single tensor
-    all_embeddings = torch.stack(embeddings, dim=0).squeeze()  # torch.Size([32, 4, 50, 768])
+    all_embeddings = torch.stack(embeddings, dim=0).squeeze()  # torch.Size([32, 4, 50, 768]) or torch.Size([32, 2, 196, 768])
     return all_embeddings
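A note on the `find_indices` speedup in `count_vision_proj`: `unfold` turns the old per-batch, per-position Python loop into one batched comparison over all length-`len(image_ids)` windows. The snippet below is a standalone, minimal sketch of that matching step; the name `find_image_spans` and the toy token values are made up for illustration, while the real method lives inside the model and uses `self.image_ids`.

```python
import torch


def find_image_spans(tokens, image_ids):
    """Return {batch_idx: [(start, end), ...]} for every run of image_ids in tokens, or None."""
    image_ids_tensor = torch.tensor(image_ids, device=tokens.device)
    n = len(image_ids)
    # unfold(1, n, 1) creates all length-n sliding windows along the sequence dim:
    # shape (batch, seq_len - n + 1, n), without copying data.
    windows = tokens.unfold(1, n, 1)
    # A window matches only if every position equals the corresponding placeholder id.
    matches = (windows == image_ids_tensor).all(dim=2)
    spans = {}
    for b in range(tokens.size(0)):
        starts = matches[b].nonzero(as_tuple=True)[0]
        if starts.numel() > 0:
            spans[b] = [(s.item(), s.item() + n - 1) for s in starts]
    return spans if spans else None


# Toy check: the placeholder run [30, 30, 32, 32] sits at positions 2..5 of row 0 only.
tokens = torch.tensor([[1, 5, 30, 30, 32, 32, 7],
                       [1, 5, 9, 9, 9, 9, 7]])
print(find_image_spans(tokens, [30, 30, 32, 32]))  # {0: [(2, 5)]}
```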
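For the encoder switch itself, a minimal usage sketch under the following assumptions: the script runs from the repository root, both encoders have already been downloaded to `model/clip_model/` and `model/siglip_model/` as described above, and `example.jpg` is a hypothetical local image. The `.vision_model` sub-module is the standard attribute on Hugging Face's `CLIPModel`/`SiglipModel`; the expected 50-vs-196 token counts follow from the shapes noted in `vision_utils.py`.

```python
import torch
from PIL import Image
from model.vision_utils import get_vision_model, get_img_process

# Hypothetical local image; any RGB(A) file works, since get_img_process
# resizes to 224x224 and converts alpha images to RGB.
image = Image.open("example.jpg")

for encoder_type in ("clip", "siglip"):
    vision_model, processor = get_vision_model(encoder_type)
    inputs = get_img_process(image, processor)
    with torch.no_grad():
        # Both CLIPModel and SiglipModel expose a .vision_model sub-module;
        # its last_hidden_state holds the per-patch features that vision_proj consumes.
        feats = vision_model.vision_model(**inputs).last_hidden_state
    print(encoder_type, feats.shape)  # expected: clip (1, 50, 768); siglip (1, 196, 768)
```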