From 1fc98152c00dbb72de6824e9ab0ccf058bab22fc Mon Sep 17 00:00:00 2001
From: gongjy <2474590974@qq.com>
Date: Sun, 3 Nov 2024 11:39:46 +0800
Subject: [PATCH] detail adjustment & formatting

---
 .envrc                                         |   1 -
 .gitignore                                     |   9 ---------
 1-pretrain_vlm.py                              |   2 +-
 2-sft_vlm.py                                   |  14 +++++++-------
 3-eval_chat.py                                 |   6 +++---
 model/LMConfig.py                              |   6 +++---
 model/__pycache__/LMConfig.cpython-310.pyc     | Bin 1601 -> 0 bytes
 model/__pycache__/dataset.cpython-310.pyc      | Bin 6234 -> 0 bytes
 model/__pycache__/model.cpython-310.pyc        | Bin 16464 -> 0 bytes
 model/__pycache__/vision_utils.cpython-310.pyc | Bin 1923 -> 0 bytes
 model/dataset.py                               |   2 +-
 model/model.py                                 |   9 +++------
 model/vision_utils.py                          |   8 +++++---
 13 files changed, 23 insertions(+), 34 deletions(-)
 delete mode 100644 .envrc
 delete mode 100644 .gitignore
 delete mode 100644 model/__pycache__/LMConfig.cpython-310.pyc
 delete mode 100644 model/__pycache__/dataset.cpython-310.pyc
 delete mode 100644 model/__pycache__/model.cpython-310.pyc
 delete mode 100644 model/__pycache__/vision_utils.cpython-310.pyc

diff --git a/.envrc b/.envrc
deleted file mode 100644
index 446626f..0000000
--- a/.envrc
+++ /dev/null
@@ -1 +0,0 @@
-source activate minimind
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
deleted file mode 100644
index f0b6441..0000000
--- a/.gitignore
+++ /dev/null
@@ -1,9 +0,0 @@
-model/clip_model/clip-vit-base-patch32/
-model/siglip_model/siglip-vit-base-patch16/
-out/*.pth
-full.json
-trans_json.py
-dataset/*
-!dataset/eval_images/
-!dataset/eval_multi_images/
-minimind-v/model/__pycache__/*
\ No newline at end of file
diff --git a/1-pretrain_vlm.py b/1-pretrain_vlm.py
index e42840e..aca5408 100644
--- a/1-pretrain_vlm.py
+++ b/1-pretrain_vlm.py
@@ -173,7 +173,7 @@ def init_distributed_mode():
     if args.visual_encoder == "clip":
         lm_config = LMConfig()
     else:
-        lm_config = LMConfig(image_special_token='<'*98+'>'*98, image_ids=[30]*98+[32]*98)
+        lm_config = LMConfig(image_special_token='<' * 98 + '>' * 98, image_ids=[30] * 98 + [32] * 98)
     max_seq_len = lm_config.max_seq_len

     args.save_dir = os.path.join(args.out_dir)
diff --git a/2-sft_vlm.py b/2-sft_vlm.py
index 0a2b1a6..877b76a 100644
--- a/2-sft_vlm.py
+++ b/2-sft_vlm.py
@@ -97,9 +97,9 @@ def train_epoch(epoch, wandb):
         if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
             model.eval()
             moe_path = '_moe' if lm_config.use_moe else ''
-            if args.multi: # save weights from multi-image training
+            if args.multi:  # save weights from multi-image training
                 ckp = f'{args.save_dir}/{lm_config.dim}{moe_path}_vlm_sft_multi.pth'
-            else: # save weights from single-image training
+            else:  # save weights from single-image training
                 ckp = f'{args.save_dir}/{lm_config.dim}{moe_path}_vlm_sft.pth'

             if isinstance(model, torch.nn.parallel.DistributedDataParallel):
@@ -197,7 +197,7 @@ def init_distributed_mode():
     if args.visual_encoder == "clip":
         lm_config = LMConfig()
     else:
-        lm_config = LMConfig(image_special_token='<'*98+'>'*98, image_ids=[30]*98+[32]*98)
+        lm_config = LMConfig(image_special_token='<' * 98 + '>' * 98, image_ids=[30] * 98 + [32] * 98)
     max_seq_len = lm_config.max_seq_len

     args.save_dir = os.path.join(args.out_dir)
@@ -229,12 +229,12 @@ def init_distributed_mode():
     if args.multi:
         print("Running multi-image training; doing this after instruction fine-tuning is recommended...")
         train_ds = SFTDataset_multi(args.data_path_multi, tokenizer, vision_model=(vision_model, preprocess),
-                                    image_special_token=lm_config.image_special_token,
-                                    max_length=max_seq_len)
+                                    image_special_token=lm_config.image_special_token,
+                                    max_length=max_seq_len)
     else:
         train_ds = SFTDataset(args.data_path, tokenizer, vision_model=(vision_model, preprocess),
-                              image_special_token=lm_config.image_special_token,
-                              max_length=max_seq_len)
+                              image_special_token=lm_config.image_special_token,
+                              max_length=max_seq_len)
     train_sampler = DistributedSampler(train_ds) if ddp else None
     train_loader = DataLoader(
         train_ds,
diff --git a/3-eval_chat.py b/3-eval_chat.py
index aa7c91a..52cb898 100644
--- a/3-eval_chat.py
+++ b/3-eval_chat.py
@@ -66,12 +66,12 @@ def setup_seed(seed):
     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
     dtype = 'bfloat16'
     max_seq_len = 1024
-    encoder_type="clip"
+    encoder_type = "clip"
     # lm_config = LMConfig()
     if encoder_type == "clip":
         lm_config = LMConfig()
     else:
-        lm_config = LMConfig(image_special_token='<'*98+'>'*98, image_ids=[30]*98+[32]*98)
+        lm_config = LMConfig(image_special_token='<' * 98 + '>' * 98, image_ids=[30] * 98 + [32] * 98)
     lm_config.max_seq_len = max_seq_len
     model, tokenizer, vision_model, preprocess = init_model(lm_config, device, multi)
     model.eval()
@@ -79,7 +79,7 @@ def setup_seed(seed):
     # -------------------------- prompt and directory setup -----------------------------------
     if multi:
         image_dir = './dataset/eval_multi_images/bird/'
-        prompt = "\n\nName all the differences between these two birds."
+        prompt = f"{lm_config.image_special_token}\n{lm_config.image_special_token}\nName all the differences between these two birds."
     else:
         image_dir = './dataset/eval_images/'
         prompt = lm_config.image_special_token + '\nWhat does this picture describe?'
diff --git a/model/LMConfig.py b/model/LMConfig.py
index 140b23f..3113da1 100644
--- a/model/LMConfig.py
+++ b/model/LMConfig.py
@@ -7,8 +7,8 @@ class LMConfig(PretrainedConfig):

     def __init__(
             self,
-            dim: int = 512, # 768
-            n_layers: int = 8, # 16
+            dim: int = 512,  # 768
+            n_layers: int = 8,  # 16
             n_heads: int = 16,
             n_kv_heads: int = 8,
             vocab_size: int = 6400,
@@ -19,7 +19,7 @@ def __init__(
             dropout: float = 0.0,
             flash_attn: bool = True,
             image_special_token: str = '<' * 25 + '>' * 25,
-            image_ids=[30] * 25 + [32] * 25,
+            image_ids: List = [30] * 25 + [32] * 25,
             ####################################################
             # Here are the specific configurations of MOE
             # When use_moe is false, the following is invalid
diff --git a/model/__pycache__/LMConfig.cpython-310.pyc b/model/__pycache__/LMConfig.cpython-310.pyc
deleted file mode 100644
index 22d8930cde58f9b708b9d50191db2690e55d9547..0000000000000000000000000000000000000000
Binary files a/model/__pycache__/LMConfig.cpython-310.pyc and /dev/null differ
diff --git a/model/__pycache__/dataset.cpython-310.pyc b/model/__pycache__/dataset.cpython-310.pyc
deleted file mode 100644
index 93d01323da016bb05678cbc6a78a1c7b19cac0bc..0000000000000000000000000000000000000000
Binary files a/model/__pycache__/dataset.cpython-310.pyc and /dev/null differ
diff --git a/model/__pycache__/model.cpython-310.pyc b/model/__pycache__/model.cpython-310.pyc
deleted file mode 100644
index 5043f059ff42598905298b9e90e921f8ba09207f..0000000000000000000000000000000000000000
Binary files a/model/__pycache__/model.cpython-310.pyc and /dev/null differ
diff --git a/model/__pycache__/vision_utils.cpython-310.pyc b/model/__pycache__/vision_utils.cpython-310.pyc
deleted file mode 100644
index f995df75d9b95b28ba4941dc7405dbaefc9fc15b..0000000000000000000000000000000000000000
Binary files a/model/__pycache__/vision_utils.cpython-310.pyc and /dev/null differ
z!NP6a<*=c!s-xB#=Y)J={n%f08@uVkxn4t->x~Dqy(f2@&T`|_BG&hi+z$LFHnA%2 zrX{1Nkiz)?08ROm&F3beU(n^HQ)}TjKG4$LO%dSvhdQ=P(BwLZ=fHA9C%`k-3)nz_ zN8L%LQ(iH>Q6y5udBogTk`PT0kc?b}72CWXAS-7eR`G(Y4;JWq-lmuDE70?P*!6`SQ{cv<@p z{}nh@rCbM)f|VTAqU1vAcT|1AtNb`o7YPz$y6eGx%=a%bDZxv9WyBK*Lmu-(*q^>y ztpl_0tJSo&3tnkgj*}_Zy;&un%s9tTH393YTBKv$t?D=v3Dd#rDX6<`1<-E^p5;uOoZIPqfN!cfHW1;)J|e6%r87 L{NSkV4TAo^Qzs7C diff --git a/model/dataset.py b/model/dataset.py index bd6ba9f..d93de28 100644 --- a/model/dataset.py +++ b/model/dataset.py @@ -178,7 +178,7 @@ def __getitem__(self, index: int): image_encoders = get_img_process(image, self.preprocess) return X_tensor, Y_tensor, loss_mask_tensor, image_encoders - + class SFTDataset_multi(Dataset): def __init__(self, json_path, tokenizer, vision_model=None, max_length=1024, diff --git a/model/model.py b/model/model.py index bfa4c65..726ad28 100644 --- a/model/model.py +++ b/model/model.py @@ -326,12 +326,12 @@ class Transformer(PreTrainedModel): config_class = LMConfig last_loss: Optional[torch.Tensor] - def __init__(self, params: LMConfig = None, vocab_size = 6400): + def __init__(self, params: LMConfig = None): super().__init__(params) if not params: params = LMConfig() self.params = params - self.vocab_size = vocab_size + self.vocab_size = params.vocab_size self.n_layers = params.n_layers # image的特殊占位符,对应每张图切分成M个token,和get_img_process中的数量对应 self.image_ids = params.image_ids @@ -373,12 +373,10 @@ def count_vision_proj(self, tokens, h, image_encoders=None, seqlen=200): def find_indices(tokens, image_ids): image_ids_tensor = torch.tensor(image_ids).to(tokens.device) len_image_ids = len(image_ids) - # .generate时,在初始化后直接跳过 if len_image_ids > tokens.size(1): - # print(f"len_image_ids ({len_image_ids}) is greater than sequence length ({tokens.size(1)}), skipping.") return None - + # 使用view来创建一个视图,便于处理滑动窗口 tokens_view = tokens.unfold(1, len_image_ids, 1) # 在第二维度创建滑动窗口 # 检查每个滑动窗口是否与image_ids_tensor相等 @@ -451,7 +449,6 @@ def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch self.last_loss = None self.OUT.__setitem__('logits', logits) - # self.OUT.__setitem__('last_loss', self.last_loss) self.OUT.__setitem__('loss', self.last_loss) return self.OUT diff --git a/model/vision_utils.py b/model/vision_utils.py index 61e14ea..adc6a2d 100644 --- a/model/vision_utils.py +++ b/model/vision_utils.py @@ -40,10 +40,11 @@ def hook_fn(module, input, output): embeddings.append(output.last_hidden_state) # 从 BatchEncoding 中提取图像张量 - if isinstance(batch_encoding, transformers.tokenization_utils_base.BatchEncoding) or isinstance(batch_encoding, transformers.feature_extraction_utils.BatchFeature): + if (isinstance(batch_encoding, transformers.tokenization_utils_base.BatchEncoding) + or isinstance(batch_encoding, transformers.feature_extraction_utils.BatchFeature)): image_tensor = batch_encoding['pixel_values'] else: - image_tensor = batch_encoding # torch.Size([32, 4, 3, 224, 224]) + image_tensor = batch_encoding # torch.Size([32, 4, 3, 224, 224]) # 如果图像张量的形状是5维,则无需添加额外维度 if len(image_tensor.shape) == 4: @@ -66,5 +67,6 @@ def hook_fn(module, input, output): hook.remove() # 拼接所有特征向量成为一个张量 - all_embeddings = torch.stack(embeddings, dim=0).squeeze() # torch.Size([32, 4, 50, 768]) or torch.Size([32, 2, 196, 768]) + all_embeddings = torch.stack(embeddings, dim=0).squeeze() + # torch.Size([32, 4, 50, 768]) or torch.Size([32, 2, 196, 768]) return all_embeddings