From e89ecd2e056212972a3274605ad1c9c5126330ce Mon Sep 17 00:00:00 2001 From: hariishaa Date: Fri, 5 May 2023 01:23:01 +0300 Subject: [PATCH 1/7] feat(metadata-ingestion): implement mlflow source --- .../app/ingest/source/builder/constants.ts | 4 + .../app/ingest/source/builder/sources.json | 7 + datahub-web-react/src/images/mlflowlogo.png | Bin 0 -> 19569 bytes .../docs/sources/mlflow/mlflow_recipe.yml | 8 + metadata-ingestion/setup.py | 7 + .../src/datahub/ingestion/source/mlflow.py | 330 ++++++++++++++++++ .../mlflow/mlflow_mcps_golden.json | 148 ++++++++ .../integration/mlflow/test_mlflow_source.py | 106 ++++++ .../tests/unit/test_mlflow_source.py | 140 ++++++++ .../main/resources/boot/data_platforms.json | 10 + 10 files changed, 760 insertions(+) create mode 100644 datahub-web-react/src/images/mlflowlogo.png create mode 100644 metadata-ingestion/docs/sources/mlflow/mlflow_recipe.yml create mode 100644 metadata-ingestion/src/datahub/ingestion/source/mlflow.py create mode 100644 metadata-ingestion/tests/integration/mlflow/mlflow_mcps_golden.json create mode 100644 metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py create mode 100644 metadata-ingestion/tests/unit/test_mlflow_source.py diff --git a/datahub-web-react/src/app/ingest/source/builder/constants.ts b/datahub-web-react/src/app/ingest/source/builder/constants.ts index 8d41c3533575a..b82cd1caab74d 100644 --- a/datahub-web-react/src/app/ingest/source/builder/constants.ts +++ b/datahub-web-react/src/app/ingest/source/builder/constants.ts @@ -27,6 +27,7 @@ import powerbiLogo from '../../../../images/powerbilogo.png'; import modeLogo from '../../../../images/modelogo.png'; import databricksLogo from '../../../../images/databrickslogo.png'; import verticaLogo from '../../../../images/verticalogo.png'; +import mlflowLogo from '../../../../images/mlflowlogo.png'; export const ATHENA = 'athena'; export const ATHENA_URN = `urn:li:dataPlatform:${ATHENA}`; @@ -61,6 +62,8 @@ export const MARIA_DB = 'mariadb'; export const MARIA_DB_URN = `urn:li:dataPlatform:${MARIA_DB}`; export const METABASE = 'metabase'; export const METABASE_URN = `urn:li:dataPlatform:${METABASE}`; +export const MLFLOW = 'mlflow'; +export const MLFLOW_URN = `urn:li:dataPlatform:${MLFLOW}`; export const MODE = 'mode'; export const MODE_URN = `urn:li:dataPlatform:${MODE}`; export const MONGO_DB = 'mongodb'; @@ -115,6 +118,7 @@ export const PLATFORM_URN_TO_LOGO = { [LOOKER_URN]: lookerLogo, [MARIA_DB_URN]: mariadbLogo, [METABASE_URN]: metabaseLogo, + [MLFLOW_URN]: mlflowLogo, [MODE_URN]: modeLogo, [MONGO_DB_URN]: mongodbLogo, [MSSQL_URN]: mssqlLogo, diff --git a/datahub-web-react/src/app/ingest/source/builder/sources.json b/datahub-web-react/src/app/ingest/source/builder/sources.json index c9db0433b3aae..70c263e421efe 100644 --- a/datahub-web-react/src/app/ingest/source/builder/sources.json +++ b/datahub-web-react/src/app/ingest/source/builder/sources.json @@ -174,6 +174,13 @@ "docsUrl": "https://datahubproject.io/docs/generated/ingestion/sources/metabase/", "recipe": "source:\n type: metabase\n config:\n # Coordinates\n connect_uri:\n\n # Credentials\n username: root\n password: example" }, + { + "urn": "urn:li:dataPlatform:mlflow", + "name": "mlflow", + "displayName": "MLflow", + "docsUrl": "https://datahubproject.io/docs/generated/ingestion/sources/mlflow/", + "recipe": "source:\n type: mlflow\n config:\n tracking_uri: tracking_uri" + }, { "urn": "urn:li:dataPlatform:mode", "name": "mode", diff --git a/datahub-web-react/src/images/mlflowlogo.png b/datahub-web-react/src/images/mlflowlogo.png new file mode 100644 index 0000000000000000000000000000000000000000..e724d1affbc14d53f0ec8d6d5304b8aac1dd4f48 GIT binary patch literal 19569 zcmeFZi$Bx<{{ZS!N>V6=a?Q2m79scOf)H}QjO3R4b#AjMlFBXQewmUx6Jm^!3d#LC zj7{#$CfBhs=dI7>`#a}vI6WSfz20xH*Yoyz-k*DI_)zQY>C30-=;+SsXy1KAM|X@w zNB7s;QzwCy^{M#<;J?58A8Bd+HQp&64t!wr(YExbqhn$_`ui7MW;Q1sogTZ+T@91q zY~mDSI=5-V;kt2~82<_(|?by)A|$t3?~2(r9j^gSx3csl;gbiEXZkdFDt znS6m)5EJ%FTOW<4A2PD{FZ^)ok^vPd-O+XyvA3(WVzBY%t3+(;a4^`^! zGJ*WUzS+8zQgZ+t zcY7)ZwX+_1vz*|h`4?1t)^+pcj2=ef81QKFsU%gbUzB3+7suFPFvp+g?OJXYwQGVQ z9zAue1+RUGm=j`qRY`Y9ln#xJ_W}T>Ht{i{7$z8WlD!3snTkxiThUft_S%~ zDFN$WntjIqf_)UM^jj|?@T=9mKXzb=Asz>wtXdi3jP6nLZ8_Cvl%6Fh*GE(6F`o2C zPhM|vvyC;{Fsw1g7mYZvkmgY-!1C>#iJ!{FSNC3|t))D%B~~%{yFBva#k$B$;iduC1YZfCpb*>?YExHhnyOcQy7M zOs@ir?>k!FjiRWqd0`rvG4e+|Z1j>n+8buSev8cm4(2JBNDwv6vs*!p{CW8%z8cNZ ze%P(?sZKQ&Mri8NCx~UmdtaozXJ22!#^@kAR+K``)8?3UAFt7!>A4^Edaf%<>gqRWm<<(zy7> z*yJtu`v}B06AeBS%b6QiG(Xd;XuW$NYH>(H)7^X5$v5!SK=ole1w55^WE36q_m&PE zBBEq?Ng!{p)@6etLTAxXu!&)2fz=EL%8LUya{Bi8=GWaFZu%|#4p|s0hlxApIYRk} z@AVwZinh(89^sQLvt=XuGXxX=@t`SQ;K>)X<{C-lDsqv(oF!IRY*@73iwRhMD<#nh z@naqB(H^C~Z+-+qJu$~f{E@~E`UUlR7Gfy+O_EeS>4@uZ4WgX1GWu5or*>|umjQx| zsQ^~#{JL1SSU_-3bNhWA4D@JagV4yLf~Q{;&3?I)I;V=j9xcvUUubD_dP){NIE#@u zTJS@xW=yYoWXAHf<9(-&&g4I|YGl6oS7>7XbhCD?#AGY@2*hXmOta)hM&f$DnBK!c zkQ%TgeQozWEhRsMSb=b706vREFI0PLeokc}bqcz{{5}yjjwG-1id91b;~?Qv zpAYEhGD(NrZOgEI=fSFmwKeI92@LpW^~Jhq+hSd`Z!35lct60sLhuUv=57ZEi_LH~HM*giuI`df-GNcdQIr^B}CxWp@XWwAl?s&0? zB>U7I{UWEJ`mp&a1~LJJIRkvuQ?gB;yuvx5P%Pd21^h?A=LP}pbF|-Uoo?3au@3x5 zsQ!b2uy*`!0rdTVOg;Y)2*&A1-ja2%s9|0#*FI_mMNtGW!ny?b8^`c^1Yqyz>OKQR zV@Y&U#&eF_&H^h zCQ&LJlCSNPFf$8E1MHj|ck9&q+X|<0J*BAj8^=e&bfKX1%Oda^9QNvX9-w#X|7>_G zqVZ1DE;oRuBvb`}9uQ=OV`$GHkSHBp@p1Z~pnbz=zgAV49>_@PC%9)&k2paSy#(;( z2m@x;NuCBk*RxT)E~w1nQq&3ofVb`LJVqnMUf~HS6FNFWhEt^Y&q7)NoANAtShnys zY2t{2fJ|+l76HXEaLGTryq>AiTY}h#)dN!h72G+f*OSJ$6&OWO`it&%4u5^)l?!6g zy{9q1lxA%|dS6^3uGwH1>F8GITc^&l?b3v}CTA5I5tf-H)JxA5?s7C9QGb{*qVYHr zyN8RWosM+9l#vG69?`{qaoJVI`&DXRAV>^Akc~cwVFE?oICfBO;saa-T{)yX3&tr% z?1H-F`AKXlAPz?mJ_PVQwapmef*o4s<|tXHUPF24z&Py$`>~L}0Xfw=_1q9HdU^6V z^aCNOAjE$E*|veqv4cAp@G-h)N}4_#qkus1wR=XiLUY94*)E_wjH06L#5Z1EebDRs z2Nps4Ao~;0w~WvnfT<76N*1i^fb{}?Ljk2qb5Xv|A>iq7mu-WRLK%qq>oiYr<7gn)~3`!ewR7vN5s zc%_nQ+Xe@>ob?ZLk^r~V|CY5o^g-90`Vs0+%BvUe@c_b$EKUNxTFk%XhNO;$W zV{o#zh7rmGkY|v~#6qPVJpLhKMm$5|ZW{$P4JhKd*8xa*80#~TM;%)NbngE8(mh5A zwmLyM&t-t*5EjN_*P=raah`xa94pLZCcRg?a`uhlsj}GCbLvGAbNk(JLh5eHd2FQs z^!@@CF#6J$Yk4MT$qg)}vHIP4LKyYBj9%Kzgs_fM`HV**-Gau{u2`$&cHv>rID#4Y6=K z!>k^=SbAcf2)x+SXX_D;O|C0Pasfc;xyjda%oJ`pt0TUNIrx;DN>GRy>U~BSHGqr} zb9XYll>u8}G3?>PGIMzIxjU5k*!BGItL;^s1Tx@9PI9npGPBp?}NzoL7{E@i<;FhV6W@(vVvM>_Sa>z+x zGPX_<6t1(oID>R-CIWWC!8EHC1b-X(=_e<5=Yd$PUCzN~c9?}wJrXgd{5S8*GS84S z)-N1EBE7GD_^4I_sxfK+z^~3aMq{YKc>&P)0?_a->)J${w)bnpS-1>p_Qr#-xIpzc z;ZGcn7}s7d5(z1m^5ZeP^eqBnp(MGvSl{2+_@j{*wLbzio0$p_JUjZEDbmh{S^j289IFuV~0iAfbqQ1SCN#`_lH6tb-F zp3S)q&9I-tu-BL1I@HHBs^9fMm`;-AN-H277$;pB zIXs{aS{f0uEY38Ck@q!9q0Vx(+t`NkqeQMh2+%n~TOn zYp$!S8S;(L`Br}~7hztyV!uh>4)hQD4=h=?0}A#+i!LR^3mU4T*YUjy>Wdz1k6C;` z7>kT))hQ3^Vt0M}@PK{$3L|~@hrN)Nn|A8%?Ntc(TJn-HHHp*zs3G(fr;uB1D z{-8Rc*9!Mx#Vw1dhdMQ0cF@}nkiHJRtKRObOX%^~50~a&pJ7{FQ&5-992p2|*9lUk zJ;BHwdv=xW*^W%>#vSMVwkJO*MVC0TI&KYv(SkbarrPQ{(%9gNOYY(fy}X+i>NnnR z<{$`lN91|GE7Dp+6nWh4`5;9VFnk#;TaRBimwpOB5gvIek>ckaEY!~m+@eRW3CY5)pZg?#DW>~J zTj^=zLOs{9V0jx9!*Pd(OLn*Re}4l)sTlEdlM<(9`vBGy^~|FCz?~mH%w!0(qH77a zF$K=Vk8SfggRT-Hx22?Kb>4q!P8skwH+Jc3bIxK8n&K+^I^!$W&nMSX2EXZjiWmpP zP^Xi~#k%IJE2y~}*Dn;Mt3eI#Q9LfL1hH^=b8$acyMcaLEEC~)KePr-v!`$^yQZ9D zPcYg{y7#-MDH%S3qsQ>dHH$4X^xb8O}h1-K36h$=&#UaGYMV zYuq6$Q_R?2&+s5=j|^h|w(`ZowLNXf^Rhb+n-eejrtE}&f8_IZiwP!Kw>eH%BpdLmpLC4zHPAYay7JFEJ37Qp)d!M+KUMU!7dPmUGwb@H1NlQL{s|8ap z8PxrCaM?BwE{&zlu|wEQn&UsH3vLN(Q**zkZxPf?+8V|i|Ff?NHcO^8ZlM2&jrH}JF^7aH^!xX} z+PPo!SH*U8%eyEcNB5g#yZfLr47G8Kip?@m-kQ~D`UpsS#O_V+v5wF>Ep$Z~Aj_3( zjB40hzvoPiA^7!V^-tdNqv@>SMxi$<2()N3TjBr9KUpq#7+iBLq zJcCb7C!L7ygc;1V6sHD%^?s|>c*uQsETml{eG56>)rfUaZQuwK}71 zFydnU@bk+AA#08nS9?5@b1fN$e=&Jh@zg}!{lVdJSw-{ROysI??zH9n&zyYGNPfjm zkG1?q(D7UHCL|ejUiP7I02TjR)xD9=xrH(+E+6D`s5mi=JCsvRZVvtKJ$Xytc{d=V z^-z@6s+ZR_s~g{!snooU*_4MVIkd<6_SES2J<19X_h)wd;TFIH-5OwsFXE77os>0^ zgnZb<3f6q{)WIzQzW%~9DiakQcesq|t$OOQ20=2`UiV`yFKE}7CD}mp2+avPn;&?3 z+^?rjDC{Q2+oUz=N3}IVWE388x!PnsZ1kUs@Vr6YCD88#DHOTX$&TaV)cwU1nY}iuCLYEtl?FD{G%%;W$OQ z1mu|pR-fwnF~IB7aRnpb@g%(|-HPa3Lu5U||E9Bem`7!MeC}bd#@kdrBDy7D>ac3# zScAJ=-K}kY1I+`k2I#zqU_ZIgyU+&txa`ABSrdO9v<{s;GQKn1-^cCbV_^Qym8j%1 z+tS#)k|j?&_~IV&+l7Vol!OZlKKJFzm&X3AlAf{Z@jHyUYEVk<0+sJ)t33Ad^&-8D zBx>zzM$)A>W+aC|al&_IHKPseR(T4cy6qze?erG-qpW(!&{(sZb54#EA7`I@;Ev5S z>TuQ`W=E_%BXcXst`QMbb~?t+qi!>No>DcwB#Zr&Opy188*e|cLuJL6BW$+!&mbGZ z`IIJNjE^56;|>vif4Plh96py_=aRDF*&qLNbB-sZ;G$$1-VZug;HO5D8+}8&rw~Zf z$K~Awo~H~)JYEW|m}_;B0X<87EgcfAZO7#6h4(Q4S>WE6)Q|_;Cvv-*3kKwB9JY-R z@W^!;dy0Y|HF}N(L6|Bv4A=9sQ`8I}T?p0rQ8!P893#+NLh4lIw){KTCm9y2ixjSM z?OUeA7-7yFklX*hB{^!npIwJ+(;}s>og1tRE%Vt_MWS%`Ti2(Dx$)(Il6;@Hd!;q9 zFSI5(i->a6fxQjGhT0vUBA|LDGW@A;(E7ok^hi456C35()RbfjVgy#ac>3g!^{|{G z8sb*JO6NU0>Ex2{2Q^d~0bj5C-ejiv96FUF_H!r5ZQ3P4vA*nimU!D?>AseUlG4H9 zH}Bq?yEP1_UZmC}sBHWI^}y5S`L!R!7Ioix@=>&V{{XBwrZzCp-~Sf7mrhS!{lS~Q zZM-~u4kM>4V?XEoycxI}y@}n;o+G*dWT}yK5~{g?zqIL#Iq3V|o2B(;`J=zXBBk z7PhR)3dHe`zw04Gi$CedSmqx$+QhNAtj8qEML_NY>9l|F3jAkTKeV4dQvE_F_?(X?L<2%WPSb)(aXy7c5^ zYi7j4*}J@JO2p{(Ao`UxlQfrbyBDoEuj}<$hjx27`&Q`mn*2*jc#*Nw{Uh;;=<4%l z#xsL3U*%Ze_%pLSYHN50oc*n2QyL`OY(#kE~NYA2@gV1F+z4|awF;g@SNwq_4+2X&no8q5NvHC3lQt58UQ)gmBHMcDtI7?5jJNGh!ov^{jK=EpFy+NXS1EVh39~nr!Ms? zcNj5Fww|RI%0eybjs?W2o;g5HO@M7eEiIDx*Rv;;?we8%1k%oGTf_S zl0p!;YMF&q-{PEQB1O8RyeZ^2)4e{*2>jNG<@2qoTtE`Kz)J5h z<0@4{bdV1Wa{Ea?o;As^d`@evJRS$hQVKmB8^b|T7_k_ab$|E3zacnTa`56!k~3zT zSsJ@3+Hf$-8G2#qBnDlgP4-PZP$raWhNbq%4D#DD&VDnCyfhFTLH*Ll`!-U#qquT? zbDj3YF>X5dd70UoeV}H-10Ik4+lM~@7y20V37Y$RmJ(!ffj_Uv;huw5ACO^$NP4ZV zx`rNjX5m2tFLsASb~$zL6}Hn;hzrjSJ5ihE5E54H2)Z) zz>`)Fn<@efskZsM*GXreFl439>mzhi=b}2pt=HUdQW88oZ~6Wzh3e-Kec8q!BsaIU zgK3M%t=+zQa*lghLZnqHpIa{GUY1H9p=WJ{ag-0@{&nt~EoV%ak_+y0eapJtg(8KA}y$ULcq{tK>vlk+}N&Gvy+v^q^2O;wU-@{G_wCi|mxKtnD2} zKboc^d0A;RtY+M%1~I%p_=|0<1D(4$1^_IcF+NUR5w}+Oz`yqvInQ*3FECW&NqTlR zj$7wx_^-e#FKF(@sjmT06=)TwV}gQ41vkfR^So<8MPlC@tn9C{DW3Apif7QXP#5gF ztnU*Vb1A$_!*li$Io&hd#SHx?pp#YK{N}8(R4r*dR90Q;+H{J+@#xaH zrQb0K_)gc(Z$FeRTnMM=b$S#YB$%4c?}fMIU-nKc_I3xT_6}ijB{*+g5ZXXU65ca9 zKep!nD8SUt*kvu@g(h`YKSwnK>XtDP=(A=iv&mP++_P~*>%wwhumjrN_w7E866pC{ zv^EvODrMqH;pdPmH0~23$FN;0HN4OX%3|Ecu3|XwYx^_S);7^5Y;IoKMc&$@VZ%mR zL>FG_p$(=v$`{EkraGxPuwSXpx4w2}%WgXVTDqyHJDR0d7-I#vaiZ-k$?#4d$MY4Z*M9*yYhys0{CKY14+k6B zUM{YP-s#1u^8{&asBmHN!{ZQ&c7kFV8bho07YP7r(m!70Qg2IrZq5;urpCIz*ttY7 zu?VdRhCZh?Osd_8#2-p0+Sw{-KdiR14ukhJepFf%`&h^kgF--OKVQPwLmy9iwlmoP zw4z~!-MyYZ{nT=E!t4mBRsKWCptrBriaTRs>pte;`TeI;)vqfy))WoUGIp8?MUvvy z)by|yB{D1)fCrG*lmp}*(-$R@V|Yz<*Y(YiLG&=zaC7l@>;Z)_HXg?P;PjBmT0iRd z*DUI>mD+vJbn7}*NXNRs6~>9QcaHiq+)=|9omb~kJMS1x@jV((6B1N+#?tCZ_kO-u zhHH=H*m4u9%>d=`S%exA(Up_&Wy1j@9J2y%6le8fUlRgZg^Eo2`PlW_)KNvegMHvY z-mB~qGJxo@xo+}#f8I84t$TOxuxP}SS4Y>cd06evy{gIA?}3CBTf?H}PQEHS-K^Ze z$hR=0TdhlOYVk>n$WJ4~oLLmx6Qc{!9%F|2a{5q2lj6RbACi~D zKXB8J@YN$ijF(aC%#n>XAv9~5+QW9!^R-gGmZGiD{eAx2G)bf|_?ZsNF2WtK1-=7; z1_y@GgWv-9OzSg5)f#W>?MJVb24mdsnZ;#bYon3#Tw+erGg%cFX&cU#2^HtG?`q!? z9@w#!lr*?629#KpKi0E;|W8OU3!JEt=okguyo(TOatanvlVuP3`W$6z;8= z<$x$IOs%&VH2!1gboIjs>aC-4E7KN!j#+wP`aIwmIK}7`-8;2qdU2*ks7q-X(P8KD zbvx9$73vJZcP4GWdY$B7-pt%h&@X3aEA4im!M_u*ECKwFj8UK8qmwz8H!{rDQr>?^ z`XO)*h#IV5y_k~XhiVxYnM*fqFpa{_SV=JJq?9j+oSN=QO*Ul7p4=C{=fhXzZno+c z(KuYuHB@zXwcL<*%RuxK&U~bLl6GmWB&{pAUp@-pau?H*Q;P+?FYR2v<9} z2!vMWYoHwu7|2aFZ8gd^umbD_WJhwPQLFj_&_lsK5oF_yOsXyd3^Y6Pux?907pVS@ zttNlZ{=GHC(nYPQT{ye6=!r?>iqP0*LAXz`*TlKgQ+;y?p~&{8lHd$OzHCsj9zYD~ zT#0~bqWY)1lzhNj5Ry52?bW@tT*XHJ?SZ=9I>2hYdbsB+@}T!K2M)!YEcT6l!0r@L z*{j-0JCse_8uNiZwBmHqrIZe!7(Mvdt;}Fy@2Qcw+V3!T=$i*xSsm9bVF-@bi1^t| z6GORQ^W2zPubA;(yzX{Bkp3~jY#)6RK2{ks1RBu-p!xay?%bEG_(h-;)(dE4y&i*c zQd!96wo_;3Wh_LhCZGa#iiIaMOu(^N((<=>!#N8pi6}vZ}BqG!g9yXjjF2 z_~~uNg|*zOT?>UG_|lkL%7FoEaxKD3Xff8rFeh^;=25gPc9-Y9txAUEon;ix^M^iI4;gBhmp1giTc_yO&xq@nOxskoA0Y)yy^>xvMtr6~}hwsF@MsHr)^!Km%{OQ&0 zmaakf51>ykrCzTpf0*kHfP$~fnUEapJ)_>an7qH`^tp=Y^&ks(bl@2v5_$?48A;rBz${?ExUF#I{IYf$xy7xPz%3 zPXNVsdP_@~;)dgIGA#Y5n1hx=-qc@sSaR7f-)2Wend6R^c3SR*k}HEQOhjc`#RR3g zg)0Q8=&L_i7@zP0$|cpyvp@+q9j#WNtGN;>vzGi|u}&0UN9E4)6!C5Iqz zpJfwpIdy`()_?%t)Kl-hc#34_Ys%LO<^Ebfuqs~L{~VcZ`(xoKlrOGiUm^$qj8Ufa z*80_If2C5fCka`db z($vZD@zoSJo%rr1g9;Es=uTgq)T8I@4MBGhr%rG*D5vW`8y}4nq9o>y2OW`z7yn~_8uGbpJrG$@Ys0}<-3E)jG{d0Bds_POyAz3 ze#c#VpGWb+faBW-wD{1`ch!MXz>X58HvT1JcBODV+hlZ2OF2e_;oa?Upu^(H*TGx? z)wlb0L=Zd!+)_%YEYK ztaWzvd0{*0!$U;~bP(S8*#M|KhiWu^C2ex4=UMToNe>Zo#a~3|?(msg3JfHu;3U98 z3vu18^g$6+ALr;j(bvaG?Nv9?zR`rtHj_fuQ!g?`=Z)*oH^$Qk8vOXUh8zdyg$vw% z6x?nZ`RYDA=}RCh?RcQ2xlc+0~FiB(g1a+5%D&_os*zq1+1XcV(*E z{3?`qYubxD{6XI&&>->vB?P~10DoIvQj$-;%hZb2s(xwXX8Somj)}cs)L(&`Xs5e+ z$Hd|%>1$i+pzDH*nx~U>WN@3Ac_3pW%O}Y5(&j55v`DfVl>B@#3yvJ#@Q2tbMXMzf zb7F-ex*RhuC&^$NCj!+A4d80mQjvlHZ^}L+#;`HsJoWcEVd>pLplYEZ>|qrivs;-Y zCl|N{N{9Nab#3<$J_c)SP|NN3#NON^Nj-rf?_@w<(VRuN!g}6-de#;z>gxQnToO3m z^>E3#2leWAh+Vu;(@+Dw0C(SfX_NvRnSCJhO^2R?t=AR^{s)pykve!GtGYHA#mfMD zASuPWN_`6*@T}-l0}W${j7&&5L&nv;I!Aqd9w^3s>uSMbU-4{?Q)Y znI^n*1;&bWu0c~&`f%Kv_ewpnaPd$xsU`ii7qKjGA9!YCpuDDqSpzt%PW7Ho z<$WJMO&`&EbH=1iVdlEXD*}4)U{G~>G1XL46&memDf0BNGO9!-*?FbF)IKb9w33A_ zphc>hXcQLzoiLTKB_-Na5jEHaG(vnPE0^?f^*9IH@8iwLXA^Vf*KONeLQ)M#A2%1T z2O%n4=*e#M>AdyVR44$Vx2u+^jN>ehJd?smb#Fakr zru$xOxA`00$Xx`1nHN_ zz}`9_XdXF5x^s#|Pu3XNm{ny`fR1mx3Qxn~roiKmJYW~prMt5t?X*X4g6%^Xn1&o_ z!oa=GSd=3)a;jgg{^j4eBb~5qK%gex#w#ZFYza=e`BJ^~sM0HO*S<$X*rg2RsL?)? zaVvINE#F9+8RaIws*ks$<7!__+r~|+y_kOJ!*@X?IpUJIaHZyW@D{fsld zxNM-4DpyoPo@bAs>LCeprnnLzEUk&u!tuRFUW;(4LJq>|0^dfzXcEEo=2{JvB>{0mji?tYymxQm)t`! zTAkH@vW!vFuv*8nAuhR?@;<+vpPr3htVoH?=0NWrX59l0ROoAwo}HB)BQ;k;zD&rU zP%_(0<5gLR6H8Ow^gevL*RX25nMzX6%^eYnfB^A5>yc)7#_>I)@)A0w%HqV7fI>uP z6n}$ZUJ;iG_A8H2SM|9`7x30aA@?83OfD5|z4b>8Ue7THk!~=V9qG%cX3&{HmDGoB~*rQP*~oWrGCGhMNE*ROw@s^-ELxYn5HTd;503kSC5w64i6m0K9Yn0{a_Azdk~w0 zysc`+?Q-`2;d!9!a-b?T$R)&FeZY_~BMvz8RX*gda}cH6kY=_m3f4?jeN&M{M9UCj z4bI!{xF{MBnzCtxn zvH%H;v1}=b3|CtHPs5dIhKSZ2N4M>?g2#gI-yh_2Kpp^&oM2J>IX zMD+R}R;v+UGdu~+KTK5{29vlmLh8>X%U2QMN435$I?s*l&RkyBE%(&=$ak51c{xkk zThap&9&Qkd9P{*g!AN#i=B=SL-5YNjDAD;(zcz|3MSAdAC@B7*{oRS~aScdfx;GXr z@Kc4&w*fOYnX>m&CJ2%Ft2Ea}rwHtLKot#}mEBCooL)6&_wc_XX|N-!v1RYx5^2BD zP#(GvUDe}s69aU(L(91OCs(r$wvel8F6%?3UnX}X1pY(T^li3~&to1=F1;W7&j6W9 zKmPHq-B81K#Dkj}R4b;H4Lr}ACnshZD7b*j z51sX6W~9>L)X(UHdnLYQ33eHe9LdIGMh4!c=Kn+S&6dk5p9MpQ@j~3!b3p!{O_(oj zsSI#^Nu}T0dvTRk7M2TQ=G$|xL%bTR!z(TnX^%cW?Dll>IN&`gv3^}OM)z_lKdnzM zi(Po_raH23W+%6dWqYduRs-OeD+Ef=aP_#b9Yy|2#3?w-2{v~YoY9_>q4q2Fa* z>YS2v_gGm~Ro$G%4PdRwK}O}>H+uKCt8FPHl!y>a^Kr{YWrb(NJk+KL70`yr*eO4p zp!0OkGEEPI#U!W#{gs=lx}Iuw%O|KAJ7xdf5`s#7*75lMw&dfpa3HpFT!Hon8Vv1y z0#5T-Q=QcGk?JF@j#E1JY9;&ms(y!B3fXFLw)9b8!h_^epjDD@eVwW8+5)6p0_?8& zTUln!(!a(Vyq#r;^jNPFJh;}?hhY=5+SurZm;qU~zik+pcQSuQuSPJgg3ltFnX|87iau!(%b~=6#{pDR=|?(EHn^*(8yA>lHe;6Z3^Rus@f!T&#A`Pt=|iW zQyxTLL1@`gwwlYcCR9`SmA9d~p1fVday0{4<}f_5XYnV?j2YKKsp3DZB!{w)DESL? z_C$f=o}h_Nb0F~{5{nlfFgCRad`Ebs-`_ma9*{A_{toF6%3`x$X1!jS0<^Ft z6Ym0+&DWoN-5s&z#M|$Y`OV&|yd1CRkL^dbxO8+J&hLCOb`6_s}Ew$lF0>+Hj-oPex&-=nlQn%*W93 zk-8I|5Y@8lzl21Aqd8mf8_jNUZkB@eQsF57-cD-#>fhy3s z+BOWh!9ap19B6+nxrBU&RJcNQ)%Bmk&!g0IfaF_SuY23nz3v5YXbKS@wdHIY$--GN z*|XA?^PE;ttr}OBho81=qG2~ z^y|acDs28QiAmo;K!cx{injE&`uwA5?^DqG;?uKI%xY#p{%mEFZ3!3f^WFCU8)a#) zEdfZ^$2gzIP$Zz2UQi+hRC$&oRJ1acLVK$(Ob#iRjJ8b$N+76*FLo)!ZztBH>cTk& zxl~Pg%YOfk8KT}BgJkczl`LE5wkc8y8gj&jyl>*X5E_%P#$U zYn9HDj5her|MFYnPsELXdbofzmyE5!4{e!kAatjO%Xe^!>_A^^NzEeAkmroZMxZ`; zJ~XvhK-~qJd9#SYgBxmv*bN(s+;Yelo`EvWjxSWz5+*vPY67--Vp;w@=3!*xfvjTIi|9E` zAY;NfD>bp^T5Tq`=s)qJJia->;xsYb+avrt4JJ1XCx!Q{D30=7e}jLr4NgLM=Q4Og zrKBszTk4sU*f_6`04+gtRI^EpN;G*PkG39r(JePu1Xp9F@vq+Ln7JN|tofkY< zhb)y(KTZViKr$yuKy!nlr%^~pVLYz+%yB9z;7#oI9x-J^?xoj-kH?@NMmm_@EQ`wa zO@!slxhkdbt=P@+>%~LwhDL_SE_%F49e;_rh+$bT(Q*uPe5FJ6qHA^=9t4e zr$jCQz3@Gc^m@98d>3XL5QNE0z+CFeoF<}5fLg%2Jb20X7sLd@QOMs&G?c74m#L?(A0DeUh&V_UEjQP4W2Q<-PP2BA-;^ zHAl+?N6!%+OP=46yNx8Oblop4N{90h!2#|KFJZ57d@4_Z)!V6W;W|ND&`(^R8 z*Qd6CX04x93A=x#xnLD`h=Ra@g}D`mM(f~*5HqI#OeE)I{$NS&fzy|h(ZIwVV3`GO z@;q)JC~3{!WaG|KqOd32_$&n!7~Ky?mzGN(ytl0y!FZ;&s)`d`VGD6jTKqM#TV{RJ zo4!GdxrYE1)<&-P$~rx(kmY%8K94W|^72%;Ke^CX=3j>jW9j4VwR0d%lj`INQSAZi zNnNhR4c)d|1TqG09&0(OEB!beWW<_1aP0{st2?W#GGPUrHDZ!M6F!(z`}M{TMQ_Kc zfMQN@;Ld5Nzn^E?((civs<_l~bw_{w^vF*XG2=fG0gfWz+Fo>J%?B>5J1kF~0j$D@ zhSehfC!S=Jd!0Z%LWy1Eon2#GKh46yWv#K?TO0M+w)jwuqGsR&IGSIC15=^1!A*W# zRF%^)9q9+H{=MJkm)eaN$g4RBYxyr8eTNHqqF8Pi>!F35)WiU|#qHR?oa~LDHaSV% z_~lnH-PLz|JBrT;!JxC8=kma1lJ|m!Ao+e^r=G*O^sih?%A1ft(lPjyQbSxPxYF(9{}68^D$AKAEb8r6shl zOyv!jf=k2}nWS?QQ1|4b|4E0~YPG8NOKrD*D!qSlcI_x8SV8J+ne<>PeL-$s+jcEi zLbM()3g(2)pok6vTM!w;fJI+)uFfRX`UF5o4wkPu0|29RxZ5m)o z{xHU@`l@9O<#>&2Q4vd8cEf-d%s%$*h}@EM9*nz<&#(<6ptc>(UvyDg!=~<-zo(@V zc0BYYrB=_*xMX0x82{yYd;}y~ZqMQ5#+CG|CU z^~m3&7$c~kJ1x4#17rSGc-w4~=&KFJK1I1yS&NKo;>r(1Y=--|-g~|ggY&8;NB$O2 zrBzmc?Qy&a4B$mj`TYAkemNw*No|<~yVqTdHR#i*JQ>aRUEL_E4WMY$zZ3-uj2aHD|yidayH zp$1KvsjVxW8c%(tXIe&oa@nhEHxrxTkx=v#asMnJ}8TMueJDDPbD^i`L6 z_i*a%DNd2qddZR3sk>3OZyOZ{GzDRLC**cd)*njdDK2?$Z~)^&I&#}Tnx3wa|gjs2`9i~30&OX^d{^-8~D0G2X7u~I!IU+MC6OIiTx!Ym%6x&T_ z>O}<9Y7&LjzuEu=dw9>Y@XC%0uuZ(%sbGpQwlZ zDf@3UqDoI5%n}MB!QKt`@EGf(6)e_VKYD{fMgAHX`w;XzB(Ozc-N3}!*A7!rhhXLN zT7gbDRRY7qdKgd6w7nGNJ};k9BoGPOAoINb^T7G*89Mi>mC2Nh1Y*?sX~J{K;j{*X ztisnhM}z4nkADp$=0*ihS=~6A;f~T7Fm35s8|o7wnlz`*X-mCY^BZi`J{$n=f&Y_O zYU>ox>HkETe7gfkH=F_Ldg*X(@qwYO_Uj1Uo*Zdd%W{v5`7f#bfsE#JzEcl;E3WUz z*Gl;eZEi*2b@oNB0UiKuYabbd^Q{AM{>%0qAOsx13j{xI(g$n?&^QM)`(5+klCWZHn=Ks6>~#_9f$%;*tvPHhDGP)F z1^vq>Io?U^PvvxJi9qK;?zfc)adk@uKx;_(d>KX%vM*59#Fm%DNyXs`utbFB9LltT zwcRe@H@AuvTi5VQioh~1L*YlV2!74{S`l9OXLu>dfzUDmy_s#sfwF z51@Ar|(DZMP6XWCKElCB&7j-x2G97hYPxSS-;(yd^@fGT<_YlYN@5-H*M0aHV1i;LRy(H1v9hXjKDlz-`b+dM=2v8E&Ht!UiL)eUE}SQ1PNp7NKoIV5pNoL(V{}3k zH0i2=Y>ce;%ZE3-V|$H_{Nig9Y5SU!bEwJg+9idL7(WYwb)NbYTzzvFID^j6ZZN^5 zKojy&0chg2n`8#27wI_5XmW^x!cLWv54#cw-J|bI5#a%}u>|L2asa7-ka;v=uH!QM z^U|Z%gWOCD!Arm!tSJZ;R$(MCB*g_v`!lPqW)#-MI8+6UE8jWL#|Zx!K{sH9O9j`V z%OmE{|7{G6GXZ0|`OBAJ7`Z<$#WM~EG3ENM8QPQoLb(4qKJDvFj$S&<#sdcVXl=SF zfo+sLf;8O*7$6$oBT@wa*)45|-D435-+nw?I-G_KxT$p`Eyv=Ki>}UelL}%|BCdEnFuI0rHdErexbtt z0RC#rW&RGtX|f`dgShl(jsNg8s8a4PCEs(Cz|3mt-```P(@MXDS5(5{imZ|ZK^S1l zk#5d5v>j{o)Bxjsm`|YmKLbVNpYBQj(+_RFBO7EcsijkU337BoP4}d7;7<4@n>W98 z+!F;H&)Jy=a`=S!sb6n>4L;{o9kKIeP&aUU$GK8%pm;+Ka8ui?2Y=R_zHrapcNy?x zk}#ks!_UjWz2sb9mVXksb@s>d*T9XaK(Szry=V9z$v;=vusnG$*VoE*(V#ORP@dPY@{g zz;Bn6_mbx`3Tl9j-+e~~Uq;7n2ky8B9>)Q!owZ`@tYgDtTlx0`_d73r_326F?ArTF zwaY?*3qU}NDYZQM0(4JGYyVffH2?bTzOUuVy4|E$};JVapsJ;!!X0C5^I-etL! zeKrEPd6o^dO)dWaap1ufz8d!KqFwo~_`m#02DW0R|NJe#Jbmq6hwroMH`}o%*6;wA zX9H_KaE6+%V0ozJ*7I#L6_ZWouWnu&_6Z@jz|mVOy9|WE$sNEQ>|NoVh5Ip9((_p>2D6-p{-ki z$Gb3I&jkh7vGpmFZ`=@eJ}XjmWx9Bfm>?rck|u6+5YVbJ7^Ck zSm`FgW1IJCL_W33NSC^`Sj9tsN9o-tpb-kjz_@$~6s=ws3mmuf2l?|%W#hi9r8mxH z{E3bI_~GHY-=Jtpc@K;fwGZE(DwO!A*t}%_B@NvEs(<}0_%Ird$&-r(_NW_$tpe3Q z(m+EPszFJf{p9yzfqln~=KXsLeSK@U-`OJ}b84L&C&G~$nu{RMk-#h-b`^9m#`U7x-2MvpXb7+@OL zu6ygE!r#2zWee2run%}ZhO_I^;J*R?PjEl}{mZ%LPej$V+cmYvxF>IYb{@F-xAfU+ z$&V2I_jPTyX=dRJK(a3~8n z?h*R3>iyKIF-?y=|JGlV|8V%nYU%Sg{$K7fl|Dab`d$s*pT9JPKm1SqS2Gv5y&l-! zTdiSt`p@skRi-!nR{|T|K0sp^*qP|x+E}#xlJ=zL&?U=f8XoU44)K2JtJM*|9=P~< p)vd6%(Uyxr1;!{q4}q=2.3.0", + }, + "mlflow-skinny": { + "mlflow-skinny>=2.3.0", + }, "mode": {"requests", "tenacity>=8.0.1"} | sqllineage_lib, "mongodb": {"pymongo[srv]>=3.11", "packaging"}, "mssql": sql_common | {"sqlalchemy-pytds>=0.3"}, @@ -542,6 +548,7 @@ def get_long_description(): "lookml = datahub.ingestion.source.looker.lookml_source:LookMLSource", "datahub-lineage-file = datahub.ingestion.source.metadata.lineage:LineageFileSource", "datahub-business-glossary = datahub.ingestion.source.metadata.business_glossary:BusinessGlossaryFileSource", + "mlflow = datahub.ingestion.source.mlflow:MLflowSource", "mode = datahub.ingestion.source.mode:ModeSource", "mongodb = datahub.ingestion.source.mongodb:MongoDBSource", "mssql = datahub.ingestion.source.sql.mssql:SQLServerSource", diff --git a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py new file mode 100644 index 0000000000000..2884269465fe8 --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py @@ -0,0 +1,330 @@ +from dataclasses import dataclass +from typing import Any, Callable, Iterable, Optional, TypeVar, Union + +from mlflow import MlflowClient +from mlflow.entities import Run +from mlflow.entities.model_registry import ModelVersion, RegisteredModel +from mlflow.store.entities import PagedList +from pydantic.fields import Field + +import datahub.emitter.mce_builder as builder +from datahub.configuration.common import ConfigModel +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.api.common import PipelineContext, WorkUnit +from datahub.ingestion.api.decorators import ( + SupportStatus, + capability, + config_class, + platform_name, + support_status, +) +from datahub.ingestion.api.source import Source, SourceCapability, SourceReport +from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.metadata.schema_classes import ( + GlobalTagsClass, + MLHyperParamClass, + MLMetricClass, + MLModelGroupPropertiesClass, + MLModelPropertiesClass, + TagAssociationClass, + TagPropertiesClass, + VersionTagClass, + _Aspect, +) + +T = TypeVar("T") + + +class MLflowConfig(ConfigModel): + tracking_uri: Optional[str] = Field( + default=None, + description="Tracking server URI. If not set, an MLflow default tracking_uri is used (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)", + ) + registry_uri: Optional[str] = Field( + default=None, + description="Registry server URI", + ) + model_name_separator: str = Field( + default="_", + description="A string which separates model name from its version (e.g. model_1 or model-1)", + ) + env: str = Field( + default=builder.DEFAULT_ENV, + description="Environment to use in namespace when constructing URNs", + ) + + +@dataclass +class MLflowRegisteredModelStageInfo: + name: str + description: str + color_hex: str + + +@platform_name("MLflow") +@config_class(MLflowConfig) +@support_status(SupportStatus.TESTING) +@capability( + SourceCapability.DESCRIPTIONS, + "Extract descriptions for MLflow Registered Models and Model Versions", +) +@capability(SourceCapability.TAGS, "Extract tags for MLflow Model Stages") +class MLflowSource(Source): + """ + ### Concept Mapping + + This ingestion source maps the following MLflow Concepts to DataHub Concepts: + + | Source Concept | DataHub Concept | Notes | + |:---------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------:|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| + | [`Registered Model`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`MlModelGroup`](https://datahubproject.io/docs/generated/metamodel/entities/mlmodelgroup/) | The name of a Model Group is the same as a Registered Model's name (e.g. my_mlflow_model) | + | [`Model Version`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`MlModel`](https://datahubproject.io/docs/generated/metamodel/entities/mlmodel/) | The name of a Model is `{registered_model_name}{model_name_separator}{model_version}` (e.g. my_mlflow_model_1 for Registered Model named my_mlflow_model and Version 1, my_mlflow_model_2, etc.) | + | [`Model Stage`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`Tag`](https://datahubproject.io/docs/generated/metamodel/entities/tag/) | The mapping between Model Stages and generated Tags is the following:
- Production: mlflow_production
- Staging: mlflow_staging
- Archived: mlflow_archived
- None: mlflow_none | + + ### Lightweight MLflow Support + + Besides a classic `mlflow` package, this plugin supports [`mlflow-skinny`](https://mlflow.org/docs/latest/quickstart_drilldown.html#python-library-options), which is a lightweight version of mlflow. To use it just replace `mlflow` with `mlflow-skinny` in the installation section below. + """ + + platform = "mlflow" + registered_model_stages_info = ( + MLflowRegisteredModelStageInfo( + name="Production", + description="Production Stage for an ML model in MLflow Model Registry", + color_hex="#308613", + ), + MLflowRegisteredModelStageInfo( + name="Staging", + description="Staging Stage for an ML model in MLflow Model Registry", + color_hex="#FACB66", + ), + MLflowRegisteredModelStageInfo( + name="Archived", + description="Archived Stage for an ML model in MLflow Model Registry", + color_hex="#5D7283", + ), + MLflowRegisteredModelStageInfo( + name="None", + description="None Stage for an ML model in MLflow Model Registry", + color_hex="#F2F4F5", + ), + ) + + def __init__(self, ctx: PipelineContext, config: MLflowConfig): + super().__init__(ctx) + self.config = config + self.report = SourceReport() + self.client = MlflowClient( + tracking_uri=self.config.tracking_uri, + registry_uri=self.config.registry_uri, + ) + + def get_report(self) -> SourceReport: + return self.report + + def get_workunits(self) -> Iterable[WorkUnit]: + yield from self._get_tags_workunits() + yield from self._get_ml_model_workunits() + + def _get_tags_workunits(self) -> Iterable[WorkUnit]: + """ + Create tags for each Stage in MLflow Model Registry. + """ + for stage_info in self.registered_model_stages_info: + tag_urn = self._make_stage_tag_urn(stage_info.name) + tag_properties = TagPropertiesClass( + name=self._make_stage_tag_name(stage_info.name), + description=stage_info.description, + colorHex=stage_info.color_hex, + ) + wu = self._create_workunit(urn=tag_urn, aspect=tag_properties) + yield wu + + def _make_stage_tag_urn(self, stage_name: str) -> str: + tag_name = self._make_stage_tag_name(stage_name) + tag_urn = builder.make_tag_urn(tag_name) + return tag_urn + + def _make_stage_tag_name(self, stage_name: str) -> str: + return f"{self.platform}_{stage_name.lower()}" + + def _create_workunit(self, urn: str, aspect: _Aspect) -> MetadataWorkUnit: + """ + Utility to create an MCP workunit. + """ + mcp = MetadataChangeProposalWrapper(entityUrn=urn, aspect=aspect) + wu = MetadataWorkUnit(id=urn, mcp=mcp) + self.report.report_workunit(wu) + return wu + + def _get_ml_model_workunits(self) -> Iterable[WorkUnit]: + """ + Traverse each Registered Model in Model Registry and generate a corresponding workunit. + """ + registered_models = self._get_mlflow_registered_models() + for registered_model in registered_models: + yield self._get_ml_group_workunit(registered_model) + model_versions = self._get_mlflow_model_versions(registered_model) + for model_version in model_versions: + run = self._get_mlflow_run(model_version) + yield self._get_ml_model_properties_workunit( + registered_model=registered_model, + model_version=model_version, + run=run, + ) + yield self._get_global_tags_workunit(model_version=model_version) + + def _get_mlflow_registered_models(self) -> Iterable[RegisteredModel]: + """ + Get all Registered Models in MLflow Model Registry. + """ + registered_models: Iterable[ + RegisteredModel + ] = self._traverse_mlflow_search_func( + search_func=self.client.search_registered_models, + ) + return registered_models + + @staticmethod + def _traverse_mlflow_search_func( + search_func: Callable[..., PagedList[T]], + **kwargs: Any, + ) -> Iterable[T]: + """ + Utility to traverse an MLflow search_* functions which return PagedList. + """ + next_page_token = None + all_pages_where_traversed = False + while not all_pages_where_traversed: + paged_list = search_func(page_token=next_page_token, **kwargs) + yield from paged_list.to_list() + next_page_token = paged_list.token + if not next_page_token: + all_pages_where_traversed = True + + def _get_ml_group_workunit(self, registered_model: RegisteredModel) -> WorkUnit: + """ + Generate an MLModelGroup workunit for an MLflow Registered Model. + """ + ml_model_group_urn = self._make_ml_model_group_urn(registered_model) + ml_model_group_properties = MLModelGroupPropertiesClass( + customProperties=registered_model.tags, + description=registered_model.description, + createdAt=registered_model.creation_timestamp, + ) + wu = self._create_workunit( + urn=ml_model_group_urn, + aspect=ml_model_group_properties, + ) + return wu + + def _make_ml_model_group_urn(self, registered_model: RegisteredModel) -> str: + urn = builder.make_ml_model_group_urn( + platform=self.platform, + group_name=registered_model.name, + env=self.config.env, + ) + return urn + + def _get_mlflow_model_versions( + self, + registered_model: RegisteredModel, + ) -> Iterable[ModelVersion]: + """ + Get all Model Versions for each Registered Model. + """ + filter_string = f"name = '{registered_model.name}'" + model_versions: Iterable[ModelVersion] = self._traverse_mlflow_search_func( + search_func=self.client.search_model_versions, + filter_string=filter_string, + ) + return model_versions + + def _get_mlflow_run(self, model_version: ModelVersion) -> Union[None, Run]: + """ + Get a Run associated with a Model Version. Some MVs may exist without Run. + """ + if model_version.run_id: + run = self.client.get_run(model_version.run_id) + return run + else: + return None + + def _get_ml_model_properties_workunit( + self, + registered_model: RegisteredModel, + model_version: ModelVersion, + run: Union[None, Run], + ) -> WorkUnit: + """ + Generate an MLModel workunit for an MLflow Model Version. + Every Model Version is a DataHub MLModel entity associated with an MLModelGroup corresponding to a Registered Model. + If a model was registered without an associated Run then hyperparams and metrics are not available. + """ + ml_model_group_urn = self._make_ml_model_group_urn(registered_model) + ml_model_urn = self._make_ml_model_urn(model_version) + if run: + hyperparams = [ + MLHyperParamClass(name=k, value=str(v)) + for k, v in run.data.params.items() + ] + training_metrics = [ + MLMetricClass(name=k, value=str(v)) for k, v in run.data.metrics.items() + ] + else: + hyperparams = None + training_metrics = None + ml_model_properties = MLModelPropertiesClass( + customProperties=model_version.tags, + externalUrl=self._make_external_url(model_version), + description=model_version.description, + date=model_version.creation_timestamp, + version=VersionTagClass(versionTag=str(model_version.version)), + hyperParams=hyperparams, + trainingMetrics=training_metrics, + # mlflow tags are dicts, but datahub tags are lists. currently use only keys from mlflow tags + tags=list(model_version.tags.keys()), + groups=[ml_model_group_urn], + ) + wu = self._create_workunit(urn=ml_model_urn, aspect=ml_model_properties) + return wu + + def _make_ml_model_urn(self, model_version: ModelVersion) -> str: + urn = builder.make_ml_model_urn( + platform=self.platform, + model_name=f"{model_version.name}{self.config.model_name_separator}{model_version.version}", + env=self.config.env, + ) + return urn + + def _make_external_url(self, model_version: ModelVersion) -> Union[None, str]: + """ + Generate URL for a Model Version to MLflow UI. + """ + base_uri = self.client.tracking_uri + if base_uri.startswith("http"): + return f"{base_uri.rstrip('/')}/#/models/{model_version.name}/versions/{model_version.version}" + else: + return None + + def _get_global_tags_workunit(self, model_version: ModelVersion) -> WorkUnit: + """ + Associate a Model Version Stage with a corresponding tag. + """ + global_tags = GlobalTagsClass( + tags=[ + TagAssociationClass( + tag=self._make_stage_tag_urn(model_version.current_stage), + ), + ] + ) + wu = self._create_workunit( + urn=self._make_ml_model_urn(model_version), + aspect=global_tags, + ) + return wu + + @classmethod + def create(cls, config_dict: dict, ctx: PipelineContext) -> Source: + config = MLflowConfig.parse_obj(config_dict) + return cls(ctx, config) diff --git a/metadata-ingestion/tests/integration/mlflow/mlflow_mcps_golden.json b/metadata-ingestion/tests/integration/mlflow/mlflow_mcps_golden.json new file mode 100644 index 0000000000000..1464b2fc6d293 --- /dev/null +++ b/metadata-ingestion/tests/integration/mlflow/mlflow_mcps_golden.json @@ -0,0 +1,148 @@ +[ +{ + "entityType": "tag", + "entityUrn": "urn:li:tag:mlflow_production", + "changeType": "UPSERT", + "aspectName": "tagProperties", + "aspect": { + "json": { + "name": "mlflow_production", + "description": "Production Stage for an ML model in MLflow Model Registry", + "colorHex": "#308613" + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "tag", + "entityUrn": "urn:li:tag:mlflow_staging", + "changeType": "UPSERT", + "aspectName": "tagProperties", + "aspect": { + "json": { + "name": "mlflow_staging", + "description": "Staging Stage for an ML model in MLflow Model Registry", + "colorHex": "#FACB66" + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "tag", + "entityUrn": "urn:li:tag:mlflow_archived", + "changeType": "UPSERT", + "aspectName": "tagProperties", + "aspect": { + "json": { + "name": "mlflow_archived", + "description": "Archived Stage for an ML model in MLflow Model Registry", + "colorHex": "#5D7283" + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "tag", + "entityUrn": "urn:li:tag:mlflow_none", + "changeType": "UPSERT", + "aspectName": "tagProperties", + "aspect": { + "json": { + "name": "mlflow_none", + "description": "None Stage for an ML model in MLflow Model Registry", + "colorHex": "#F2F4F5" + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "mlModelGroup", + "entityUrn": "urn:li:mlModelGroup:(urn:li:dataPlatform:mlflow,test-model,PROD)", + "changeType": "UPSERT", + "aspectName": "mlModelGroupProperties", + "aspect": { + "json": { + "customProperties": { + "model_env": "test", + "model_id": "1" + }, + "description": "This a test registered model", + "createdAt": 1615443388097 + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "mlModel", + "entityUrn": "urn:li:mlModel:(urn:li:dataPlatform:mlflow,test-model_1,PROD)", + "changeType": "UPSERT", + "aspectName": "mlModelProperties", + "aspect": { + "json": { + "customProperties": { + "model_version_id": "1" + }, + "date": 1615443388097, + "version": { + "versionTag": "1" + }, + "hyperParams": [ + { + "name": "p", + "value": "1" + } + ], + "trainingMetrics": [ + { + "name": "m", + "value": "0.85" + } + ], + "tags": [ + "model_version_id" + ], + "groups": [ + "urn:li:mlModelGroup:(urn:li:dataPlatform:mlflow,test-model,PROD)" + ] + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "mlModel", + "entityUrn": "urn:li:mlModel:(urn:li:dataPlatform:mlflow,test-model_1,PROD)", + "changeType": "UPSERT", + "aspectName": "globalTags", + "aspect": { + "json": { + "tags": [ + { + "tag": "urn:li:tag:mlflow_archived" + } + ] + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +} +] \ No newline at end of file diff --git a/metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py b/metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py new file mode 100644 index 0000000000000..155199d5a04e9 --- /dev/null +++ b/metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py @@ -0,0 +1,106 @@ +from pathlib import Path +from typing import Any, Dict, TypeVar + +import pytest +from mlflow import MlflowClient + +from datahub.ingestion.run.pipeline import Pipeline +from tests.test_helpers import mce_helpers + +T = TypeVar("T") + + +@pytest.fixture +def tracking_uri(tmp_path: Path) -> str: + return str(tmp_path / "mlruns") + + +@pytest.fixture +def sink_file_path(tmp_path: Path) -> str: + return str(tmp_path / "mlflow_source_mcps.json") + + +@pytest.fixture +def pipeline_config(tracking_uri: str, sink_file_path: str) -> Dict[str, Any]: + source_type = "mlflow" + return { + "run_id": "mlflow-source-test", + "source": { + "type": source_type, + "config": { + "tracking_uri": tracking_uri, + }, + }, + "sink": { + "type": "file", + "config": { + "filename": sink_file_path, + }, + }, + } + + +@pytest.fixture +def generate_mlflow_data(tracking_uri: str) -> None: + client = MlflowClient(tracking_uri=tracking_uri) + experiment_name = "test-experiment" + run_name = "test-run" + model_name = "test-model" + test_experiment_id = client.create_experiment(experiment_name) + test_run = client.create_run( + experiment_id=test_experiment_id, + run_name=run_name, + ) + client.log_param( + run_id=test_run.info.run_id, + key="p", + value=1, + ) + client.log_metric( + run_id=test_run.info.run_id, + key="m", + value=0.85, + ) + client.create_registered_model( + name=model_name, + tags=dict( + model_id=1, + model_env="test", + ), + description="This a test registered model", + ) + client.create_model_version( + name=model_name, + source="dummy_dir/dummy_file", + run_id=test_run.info.run_id, + tags=dict(model_version_id=1), + ) + client.transition_model_version_stage( + name=model_name, + version="1", + stage="Archived", + ) + + +def test_ingestion( + pytestconfig, + mock_time, + sink_file_path, + pipeline_config, + generate_mlflow_data, +): + print(f"MCPs file path: {sink_file_path}") + golden_file_path = ( + pytestconfig.rootpath / "tests/integration/mlflow/mlflow_mcps_golden.json" + ) + + pipeline = Pipeline.create(pipeline_config) + pipeline.run() + pipeline.pretty_print_summary() + pipeline.raise_from_status() + + mce_helpers.check_golden_file( + pytestconfig=pytestconfig, + output_path=sink_file_path, + golden_path=golden_file_path, + ) diff --git a/metadata-ingestion/tests/unit/test_mlflow_source.py b/metadata-ingestion/tests/unit/test_mlflow_source.py new file mode 100644 index 0000000000000..374816055b216 --- /dev/null +++ b/metadata-ingestion/tests/unit/test_mlflow_source.py @@ -0,0 +1,140 @@ +import datetime +from pathlib import Path +from typing import Any, TypeVar, Union + +import pytest +from mlflow import MlflowClient +from mlflow.entities.model_registry import RegisteredModel +from mlflow.entities.model_registry.model_version import ModelVersion +from mlflow.store.entities import PagedList + +from datahub.ingestion.api.common import PipelineContext +from datahub.ingestion.source.mlflow import MLflowConfig, MLflowSource + +T = TypeVar("T") + + +@pytest.fixture +def tracking_uri(tmp_path: Path) -> str: + return str(tmp_path / "mlruns") + + +@pytest.fixture +def source(tracking_uri: str) -> MLflowSource: + return MLflowSource( + ctx=PipelineContext(run_id="mlflow-source-test"), + config=MLflowConfig(tracking_uri=tracking_uri), + ) + + +@pytest.fixture +def registered_model(source: MLflowSource) -> RegisteredModel: + model_name = "abc" + return RegisteredModel(name=model_name) + + +@pytest.fixture +def model_version( + source: MLflowSource, + registered_model: RegisteredModel, +) -> ModelVersion: + version = "1" + return ModelVersion( + name=registered_model.name, + version=version, + creation_timestamp=datetime.datetime.now(), + ) + + +def dummy_search_func(page_token: Union[None, str], **kwargs: Any) -> PagedList[T]: + dummy_pages = dict( + page_1=PagedList(items=["a", "b"], token="page_2"), + page_2=PagedList(items=["c", "d"], token="page_3"), + page_3=PagedList(items=["e"], token=None), + ) + if page_token is None: + page_to_return = dummy_pages["page_1"] + else: + page_to_return = dummy_pages[page_token] + if kwargs.get("case", "") == "upper": + page_to_return = PagedList( + items=[e.upper() for e in page_to_return.to_list()], + token=page_to_return.token, + ) + return page_to_return + + +def test_stages(source): + mlflow_registered_model_stages = { + "Production", + "Staging", + "Archived", + None, + } + workunits = source._get_tags_workunits() + names = [wu.get_metadata()["metadata"].aspect.name for wu in workunits] + + assert len(names) == len(mlflow_registered_model_stages) + assert set(names) == { + "mlflow_" + str(stage).lower() for stage in mlflow_registered_model_stages + } + + +def test_config_model_name_separator(source, model_version): + name_version_sep = "+" + source.config.model_name_separator = name_version_sep + expected_model_name = ( + f"{model_version.name}{name_version_sep}{model_version.version}" + ) + expected_urn = f"urn:li:mlModel:(urn:li:dataPlatform:mlflow,{expected_model_name},{source.config.env})" + + urn = source._make_ml_model_urn(model_version) + + assert urn == expected_urn + + +def test_model_without_run(source, registered_model, model_version): + run = source._get_mlflow_run(model_version) + wu = source._get_ml_model_properties_workunit( + registered_model=registered_model, + model_version=model_version, + run=run, + ) + aspect = wu.get_metadata()["metadata"].aspect + + assert aspect.hyperParams is None + assert aspect.trainingMetrics is None + + +def test_traverse_mlflow_search_func(source): + expected_items = ["a", "b", "c", "d", "e"] + + items = list(source._traverse_mlflow_search_func(dummy_search_func)) + + assert items == expected_items + + +def test_traverse_mlflow_search_func_with_kwargs(source): + expected_items = ["A", "B", "C", "D", "E"] + + items = list(source._traverse_mlflow_search_func(dummy_search_func, case="upper")) + + assert items == expected_items + + +def test_make_external_link_local(source, model_version): + expected_url = None + + url = source._make_external_url(model_version) + + assert url == expected_url + + +def test_make_external_link_remote(source, model_version): + tracking_uri_remote = "https://dummy-mlflow-tracking-server.org" + source.client = MlflowClient(tracking_uri=tracking_uri_remote) + expected_url = f"{tracking_uri_remote}/#/models/{model_version.name}/versions/{model_version.version}" + + url = source._make_external_url(model_version) + + assert url == expected_url diff --git a/metadata-service/war/src/main/resources/boot/data_platforms.json b/metadata-service/war/src/main/resources/boot/data_platforms.json index fc285353f5005..fb5efb319b43f 100644 --- a/metadata-service/war/src/main/resources/boot/data_platforms.json +++ b/metadata-service/war/src/main/resources/boot/data_platforms.json @@ -346,6 +346,16 @@ "logoUrl": "/assets/platforms/sagemakerlogo.png" } }, + { + "urn": "urn:li:dataPlatform:mlflow", + "aspect": { + "datasetNameDelimiter": ".", + "name": "mlflow", + "displayName": "MLflow", + "type": "OTHERS", + "logoUrl": "/assets/platforms/mlflowlogo.png" + } + }, { "urn": "urn:li:dataPlatform:glue", "aspect": { From ea8c205b534ed4430ba19962f7e91cfa5170429b Mon Sep 17 00:00:00 2001 From: hariishaa Date: Wed, 6 Sep 2023 10:28:00 +0300 Subject: [PATCH 2/7] Make code review fixes --- .../docs/sources/mlflow/mlflow_pre.md | 9 ++++ metadata-ingestion/setup.py | 3 -- .../src/datahub/ingestion/source/mlflow.py | 45 +++++-------------- 3 files changed, 21 insertions(+), 36 deletions(-) create mode 100644 metadata-ingestion/docs/sources/mlflow/mlflow_pre.md diff --git a/metadata-ingestion/docs/sources/mlflow/mlflow_pre.md b/metadata-ingestion/docs/sources/mlflow/mlflow_pre.md new file mode 100644 index 0000000000000..fc499a7a3b2b8 --- /dev/null +++ b/metadata-ingestion/docs/sources/mlflow/mlflow_pre.md @@ -0,0 +1,9 @@ +### Concept Mapping + +This ingestion source maps the following MLflow Concepts to DataHub Concepts: + +| Source Concept | DataHub Concept | Notes | +|:---------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------:|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [`Registered Model`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`MlModelGroup`](https://datahubproject.io/docs/generated/metamodel/entities/mlmodelgroup/) | The name of a Model Group is the same as a Registered Model's name (e.g. my_mlflow_model) | +| [`Model Version`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`MlModel`](https://datahubproject.io/docs/generated/metamodel/entities/mlmodel/) | The name of a Model is `{registered_model_name}{model_name_separator}{model_version}` (e.g. my_mlflow_model_1 for Registered Model named my_mlflow_model and Version 1, my_mlflow_model_2, etc.) | +| [`Model Stage`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`Tag`](https://datahubproject.io/docs/generated/metamodel/entities/tag/) | The mapping between Model Stages and generated Tags is the following:
- Production: mlflow_production
- Staging: mlflow_staging
- Archived: mlflow_archived
- None: mlflow_none | diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index 43480e0a843f8..0d571b435c9fe 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -351,9 +351,6 @@ def get_long_description(): "lookml": looker_common, "metabase": {"requests"} | sqllineage_lib, "mlflow": { - "mlflow>=2.3.0", - }, - "mlflow-skinny": { "mlflow-skinny>=2.3.0", }, "mode": {"requests", "tenacity>=8.0.1"} | sqllineage_lib, diff --git a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py index 2884269465fe8..1f18700b23030 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py +++ b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass from typing import Any, Callable, Iterable, Optional, TypeVar, Union +from dataclasses import dataclass from mlflow import MlflowClient from mlflow.entities import Run from mlflow.entities.model_registry import ModelVersion, RegisteredModel @@ -8,7 +8,7 @@ from pydantic.fields import Field import datahub.emitter.mce_builder as builder -from datahub.configuration.common import ConfigModel +from datahub.configuration.source_common import EnvConfigMixin from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.ingestion.api.common import PipelineContext, WorkUnit from datahub.ingestion.api.decorators import ( @@ -35,23 +35,19 @@ T = TypeVar("T") -class MLflowConfig(ConfigModel): +class MLflowConfig(EnvConfigMixin): tracking_uri: Optional[str] = Field( default=None, description="Tracking server URI. If not set, an MLflow default tracking_uri is used (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)", ) registry_uri: Optional[str] = Field( default=None, - description="Registry server URI", + description="Registry server URI. If not set, an MLflow default registry_uri is used (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)", ) model_name_separator: str = Field( default="_", description="A string which separates model name from its version (e.g. model_1 or model-1)", ) - env: str = Field( - default=builder.DEFAULT_ENV, - description="Environment to use in namespace when constructing URNs", - ) @dataclass @@ -68,24 +64,8 @@ class MLflowRegisteredModelStageInfo: SourceCapability.DESCRIPTIONS, "Extract descriptions for MLflow Registered Models and Model Versions", ) -@capability(SourceCapability.TAGS, "Extract tags for MLflow Model Stages") +@capability(SourceCapability.TAGS, "Extract tags for MLflow Registered Model Stages") class MLflowSource(Source): - """ - ### Concept Mapping - - This ingestion source maps the following MLflow Concepts to DataHub Concepts: - - | Source Concept | DataHub Concept | Notes | - |:---------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------:|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| - | [`Registered Model`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`MlModelGroup`](https://datahubproject.io/docs/generated/metamodel/entities/mlmodelgroup/) | The name of a Model Group is the same as a Registered Model's name (e.g. my_mlflow_model) | - | [`Model Version`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`MlModel`](https://datahubproject.io/docs/generated/metamodel/entities/mlmodel/) | The name of a Model is `{registered_model_name}{model_name_separator}{model_version}` (e.g. my_mlflow_model_1 for Registered Model named my_mlflow_model and Version 1, my_mlflow_model_2, etc.) | - | [`Model Stage`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`Tag`](https://datahubproject.io/docs/generated/metamodel/entities/tag/) | The mapping between Model Stages and generated Tags is the following:
- Production: mlflow_production
- Staging: mlflow_staging
- Archived: mlflow_archived
- None: mlflow_none | - - ### Lightweight MLflow Support - - Besides a classic `mlflow` package, this plugin supports [`mlflow-skinny`](https://mlflow.org/docs/latest/quickstart_drilldown.html#python-library-options), which is a lightweight version of mlflow. To use it just replace `mlflow` with `mlflow-skinny` in the installation section below. - """ - platform = "mlflow" registered_model_stages_info = ( MLflowRegisteredModelStageInfo( @@ -122,7 +102,7 @@ def __init__(self, ctx: PipelineContext, config: MLflowConfig): def get_report(self) -> SourceReport: return self.report - def get_workunits(self) -> Iterable[WorkUnit]: + def get_workunits_internal(self) -> Iterable[WorkUnit]: yield from self._get_tags_workunits() yield from self._get_ml_model_workunits() @@ -152,10 +132,10 @@ def _create_workunit(self, urn: str, aspect: _Aspect) -> MetadataWorkUnit: """ Utility to create an MCP workunit. """ - mcp = MetadataChangeProposalWrapper(entityUrn=urn, aspect=aspect) - wu = MetadataWorkUnit(id=urn, mcp=mcp) - self.report.report_workunit(wu) - return wu + return MetadataChangeProposalWrapper( + entityUrn=urn, + aspect=aspect, + ).as_workunit() def _get_ml_model_workunits(self) -> Iterable[WorkUnit]: """ @@ -194,13 +174,12 @@ def _traverse_mlflow_search_func( Utility to traverse an MLflow search_* functions which return PagedList. """ next_page_token = None - all_pages_where_traversed = False - while not all_pages_where_traversed: + while True: paged_list = search_func(page_token=next_page_token, **kwargs) yield from paged_list.to_list() next_page_token = paged_list.token if not next_page_token: - all_pages_where_traversed = True + return def _get_ml_group_workunit(self, registered_model: RegisteredModel) -> WorkUnit: """ From e2b6ba5232726a96d284a67fd80521bb0c9d3328 Mon Sep 17 00:00:00 2001 From: hariishaa Date: Wed, 6 Sep 2023 12:51:03 +0300 Subject: [PATCH 3/7] Replace WorkUnit with MetadataWorkunit --- .../src/datahub/ingestion/source/mlflow.py | 22 +++-- .../mlflow/mlflow_mcps_golden.json | 90 +++++++++++++++++++ 2 files changed, 104 insertions(+), 8 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py index 1f18700b23030..cef6d2b1bb577 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py +++ b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py @@ -1,6 +1,6 @@ +from dataclasses import dataclass from typing import Any, Callable, Iterable, Optional, TypeVar, Union -from dataclasses import dataclass from mlflow import MlflowClient from mlflow.entities import Run from mlflow.entities.model_registry import ModelVersion, RegisteredModel @@ -10,7 +10,7 @@ import datahub.emitter.mce_builder as builder from datahub.configuration.source_common import EnvConfigMixin from datahub.emitter.mcp import MetadataChangeProposalWrapper -from datahub.ingestion.api.common import PipelineContext, WorkUnit +from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.decorators import ( SupportStatus, capability, @@ -102,11 +102,11 @@ def __init__(self, ctx: PipelineContext, config: MLflowConfig): def get_report(self) -> SourceReport: return self.report - def get_workunits_internal(self) -> Iterable[WorkUnit]: + def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: yield from self._get_tags_workunits() yield from self._get_ml_model_workunits() - def _get_tags_workunits(self) -> Iterable[WorkUnit]: + def _get_tags_workunits(self) -> Iterable[MetadataWorkUnit]: """ Create tags for each Stage in MLflow Model Registry. """ @@ -137,7 +137,7 @@ def _create_workunit(self, urn: str, aspect: _Aspect) -> MetadataWorkUnit: aspect=aspect, ).as_workunit() - def _get_ml_model_workunits(self) -> Iterable[WorkUnit]: + def _get_ml_model_workunits(self) -> Iterable[MetadataWorkUnit]: """ Traverse each Registered Model in Model Registry and generate a corresponding workunit. """ @@ -181,7 +181,10 @@ def _traverse_mlflow_search_func( if not next_page_token: return - def _get_ml_group_workunit(self, registered_model: RegisteredModel) -> WorkUnit: + def _get_ml_group_workunit( + self, + registered_model: RegisteredModel, + ) -> MetadataWorkUnit: """ Generate an MLModelGroup workunit for an MLflow Registered Model. """ @@ -234,7 +237,7 @@ def _get_ml_model_properties_workunit( registered_model: RegisteredModel, model_version: ModelVersion, run: Union[None, Run], - ) -> WorkUnit: + ) -> MetadataWorkUnit: """ Generate an MLModel workunit for an MLflow Model Version. Every Model Version is a DataHub MLModel entity associated with an MLModelGroup corresponding to a Registered Model. @@ -286,7 +289,10 @@ def _make_external_url(self, model_version: ModelVersion) -> Union[None, str]: else: return None - def _get_global_tags_workunit(self, model_version: ModelVersion) -> WorkUnit: + def _get_global_tags_workunit( + self, + model_version: ModelVersion, + ) -> MetadataWorkUnit: """ Associate a Model Version Stage with a corresponding tag. """ diff --git a/metadata-ingestion/tests/integration/mlflow/mlflow_mcps_golden.json b/metadata-ingestion/tests/integration/mlflow/mlflow_mcps_golden.json index 1464b2fc6d293..c70625c74d998 100644 --- a/metadata-ingestion/tests/integration/mlflow/mlflow_mcps_golden.json +++ b/metadata-ingestion/tests/integration/mlflow/mlflow_mcps_golden.json @@ -144,5 +144,95 @@ "lastObserved": 1615443388097, "runId": "mlflow-source-test" } +}, +{ + "entityType": "mlModel", + "entityUrn": "urn:li:mlModel:(urn:li:dataPlatform:mlflow,test-model_1,PROD)", + "changeType": "UPSERT", + "aspectName": "status", + "aspect": { + "json": { + "removed": false + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "mlModelGroup", + "entityUrn": "urn:li:mlModelGroup:(urn:li:dataPlatform:mlflow,test-model,PROD)", + "changeType": "UPSERT", + "aspectName": "status", + "aspect": { + "json": { + "removed": false + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "tag", + "entityUrn": "urn:li:tag:mlflow_staging", + "changeType": "UPSERT", + "aspectName": "status", + "aspect": { + "json": { + "removed": false + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "tag", + "entityUrn": "urn:li:tag:mlflow_archived", + "changeType": "UPSERT", + "aspectName": "status", + "aspect": { + "json": { + "removed": false + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "tag", + "entityUrn": "urn:li:tag:mlflow_production", + "changeType": "UPSERT", + "aspectName": "status", + "aspect": { + "json": { + "removed": false + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "tag", + "entityUrn": "urn:li:tag:mlflow_none", + "changeType": "UPSERT", + "aspectName": "status", + "aspect": { + "json": { + "removed": false + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } } ] \ No newline at end of file From df2f6bd1ab409f9cee42f78895840197b1c89cc3 Mon Sep 17 00:00:00 2001 From: hariishaa Date: Wed, 6 Sep 2023 13:35:09 +0300 Subject: [PATCH 4/7] Code polishing --- metadata-ingestion/setup.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index 0d571b435c9fe..8461d768fd12a 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -350,9 +350,7 @@ def get_long_description(): "looker": looker_common, "lookml": looker_common, "metabase": {"requests"} | sqllineage_lib, - "mlflow": { - "mlflow-skinny>=2.3.0", - }, + "mlflow": {"mlflow-skinny>=2.3.0"}, "mode": {"requests", "tenacity>=8.0.1"} | sqllineage_lib, "mongodb": {"pymongo[srv]>=3.11", "packaging"}, "mssql": sql_common | {"sqlalchemy-pytds>=0.3"}, From e6a851f140d783509ea378fdb259bd19db6ae99e Mon Sep 17 00:00:00 2001 From: Andrew Sikowitz Date: Fri, 22 Sep 2023 14:08:11 -0400 Subject: [PATCH 5/7] add mlflow to dev --- metadata-ingestion/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index 0dac385f11b46..798df6ca54405 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -506,6 +506,7 @@ def get_long_description(): "nifi", "vertica", "mode", + "mlflow", ] if plugin for dependency in plugins[plugin] From 4041a1d2823e59a67aea851ff19ab7357f37a758 Mon Sep 17 00:00:00 2001 From: Andrew Sikowitz Date: Fri, 22 Sep 2023 15:19:57 -0400 Subject: [PATCH 6/7] mlflow only for >= 3.8 --- metadata-ingestion/setup.py | 2 +- .../src/datahub/ingestion/source/mlflow.py | 6 + .../integration/mlflow/test_mlflow_source.py | 184 +++++++++--------- 3 files changed, 98 insertions(+), 94 deletions(-) diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index 798df6ca54405..9073a20f9f84f 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -475,6 +475,7 @@ def get_long_description(): "elasticsearch", "feast" if sys.version_info >= (3, 8) else None, "iceberg" if sys.version_info >= (3, 8) else None, + "mlflow" if sys.version_info >= (3, 8) else None, "json-schema", "ldap", "looker", @@ -506,7 +507,6 @@ def get_long_description(): "nifi", "vertica", "mode", - "mlflow", ] if plugin for dependency in plugins[plugin] diff --git a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py index cef6d2b1bb577..0668defe7b0c6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py +++ b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py @@ -1,3 +1,9 @@ +import sys + +if sys.version_info < (3, 8): + raise ImportError("MLflow is only supported on Python 3.8+") + + from dataclasses import dataclass from typing import Any, Callable, Iterable, Optional, TypeVar, Union diff --git a/metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py b/metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py index 155199d5a04e9..76af666526555 100644 --- a/metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py +++ b/metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py @@ -1,106 +1,104 @@ -from pathlib import Path -from typing import Any, Dict, TypeVar +import sys -import pytest -from mlflow import MlflowClient +if sys.version_info >= (3, 8): + from pathlib import Path + from typing import Any, Dict, TypeVar -from datahub.ingestion.run.pipeline import Pipeline -from tests.test_helpers import mce_helpers + import pytest + from mlflow import MlflowClient -T = TypeVar("T") + from datahub.ingestion.run.pipeline import Pipeline + from tests.test_helpers import mce_helpers + T = TypeVar("T") -@pytest.fixture -def tracking_uri(tmp_path: Path) -> str: - return str(tmp_path / "mlruns") + @pytest.fixture + def tracking_uri(tmp_path: Path) -> str: + return str(tmp_path / "mlruns") + @pytest.fixture + def sink_file_path(tmp_path: Path) -> str: + return str(tmp_path / "mlflow_source_mcps.json") -@pytest.fixture -def sink_file_path(tmp_path: Path) -> str: - return str(tmp_path / "mlflow_source_mcps.json") - - -@pytest.fixture -def pipeline_config(tracking_uri: str, sink_file_path: str) -> Dict[str, Any]: - source_type = "mlflow" - return { - "run_id": "mlflow-source-test", - "source": { - "type": source_type, - "config": { - "tracking_uri": tracking_uri, + @pytest.fixture + def pipeline_config(tracking_uri: str, sink_file_path: str) -> Dict[str, Any]: + source_type = "mlflow" + return { + "run_id": "mlflow-source-test", + "source": { + "type": source_type, + "config": { + "tracking_uri": tracking_uri, + }, }, - }, - "sink": { - "type": "file", - "config": { - "filename": sink_file_path, + "sink": { + "type": "file", + "config": { + "filename": sink_file_path, + }, }, - }, - } - - -@pytest.fixture -def generate_mlflow_data(tracking_uri: str) -> None: - client = MlflowClient(tracking_uri=tracking_uri) - experiment_name = "test-experiment" - run_name = "test-run" - model_name = "test-model" - test_experiment_id = client.create_experiment(experiment_name) - test_run = client.create_run( - experiment_id=test_experiment_id, - run_name=run_name, - ) - client.log_param( - run_id=test_run.info.run_id, - key="p", - value=1, - ) - client.log_metric( - run_id=test_run.info.run_id, - key="m", - value=0.85, - ) - client.create_registered_model( - name=model_name, - tags=dict( - model_id=1, - model_env="test", - ), - description="This a test registered model", - ) - client.create_model_version( - name=model_name, - source="dummy_dir/dummy_file", - run_id=test_run.info.run_id, - tags=dict(model_version_id=1), - ) - client.transition_model_version_stage( - name=model_name, - version="1", - stage="Archived", - ) + } + @pytest.fixture + def generate_mlflow_data(tracking_uri: str) -> None: + client = MlflowClient(tracking_uri=tracking_uri) + experiment_name = "test-experiment" + run_name = "test-run" + model_name = "test-model" + test_experiment_id = client.create_experiment(experiment_name) + test_run = client.create_run( + experiment_id=test_experiment_id, + run_name=run_name, + ) + client.log_param( + run_id=test_run.info.run_id, + key="p", + value=1, + ) + client.log_metric( + run_id=test_run.info.run_id, + key="m", + value=0.85, + ) + client.create_registered_model( + name=model_name, + tags=dict( + model_id=1, + model_env="test", + ), + description="This a test registered model", + ) + client.create_model_version( + name=model_name, + source="dummy_dir/dummy_file", + run_id=test_run.info.run_id, + tags=dict(model_version_id=1), + ) + client.transition_model_version_stage( + name=model_name, + version="1", + stage="Archived", + ) -def test_ingestion( - pytestconfig, - mock_time, - sink_file_path, - pipeline_config, - generate_mlflow_data, -): - print(f"MCPs file path: {sink_file_path}") - golden_file_path = ( - pytestconfig.rootpath / "tests/integration/mlflow/mlflow_mcps_golden.json" - ) + def test_ingestion( + pytestconfig, + mock_time, + sink_file_path, + pipeline_config, + generate_mlflow_data, + ): + print(f"MCPs file path: {sink_file_path}") + golden_file_path = ( + pytestconfig.rootpath / "tests/integration/mlflow/mlflow_mcps_golden.json" + ) - pipeline = Pipeline.create(pipeline_config) - pipeline.run() - pipeline.pretty_print_summary() - pipeline.raise_from_status() + pipeline = Pipeline.create(pipeline_config) + pipeline.run() + pipeline.pretty_print_summary() + pipeline.raise_from_status() - mce_helpers.check_golden_file( - pytestconfig=pytestconfig, - output_path=sink_file_path, - golden_path=golden_file_path, - ) + mce_helpers.check_golden_file( + pytestconfig=pytestconfig, + output_path=sink_file_path, + golden_path=golden_file_path, + ) From a2df1aa5106afe2a595dd14190d36b1a33537ae2 Mon Sep 17 00:00:00 2001 From: Andrew Sikowitz Date: Mon, 25 Sep 2023 16:47:42 -0400 Subject: [PATCH 7/7] skip unit tests --- .../tests/unit/test_mlflow_source.py | 225 +++++++++--------- 1 file changed, 109 insertions(+), 116 deletions(-) diff --git a/metadata-ingestion/tests/unit/test_mlflow_source.py b/metadata-ingestion/tests/unit/test_mlflow_source.py index 374816055b216..97b5afd3d6a4e 100644 --- a/metadata-ingestion/tests/unit/test_mlflow_source.py +++ b/metadata-ingestion/tests/unit/test_mlflow_source.py @@ -1,140 +1,133 @@ -import datetime -from pathlib import Path -from typing import Any, TypeVar, Union - -import pytest -from mlflow import MlflowClient -from mlflow.entities.model_registry import RegisteredModel -from mlflow.entities.model_registry.model_version import ModelVersion -from mlflow.store.entities import PagedList - -from datahub.ingestion.api.common import PipelineContext -from datahub.ingestion.source.mlflow import MLflowConfig, MLflowSource - -T = TypeVar("T") - - -@pytest.fixture -def tracking_uri(tmp_path: Path) -> str: - return str(tmp_path / "mlruns") - - -@pytest.fixture -def source(tracking_uri: str) -> MLflowSource: - return MLflowSource( - ctx=PipelineContext(run_id="mlflow-source-test"), - config=MLflowConfig(tracking_uri=tracking_uri), - ) - - -@pytest.fixture -def registered_model(source: MLflowSource) -> RegisteredModel: - model_name = "abc" - return RegisteredModel(name=model_name) - - -@pytest.fixture -def model_version( - source: MLflowSource, - registered_model: RegisteredModel, -) -> ModelVersion: - version = "1" - return ModelVersion( - name=registered_model.name, - version=version, - creation_timestamp=datetime.datetime.now(), - ) - - -def dummy_search_func(page_token: Union[None, str], **kwargs: Any) -> PagedList[T]: - dummy_pages = dict( - page_1=PagedList(items=["a", "b"], token="page_2"), - page_2=PagedList(items=["c", "d"], token="page_3"), - page_3=PagedList(items=["e"], token=None), - ) - if page_token is None: - page_to_return = dummy_pages["page_1"] - else: - page_to_return = dummy_pages[page_token] - if kwargs.get("case", "") == "upper": - page_to_return = PagedList( - items=[e.upper() for e in page_to_return.to_list()], - token=page_to_return.token, - ) - return page_to_return - - -def test_stages(source): - mlflow_registered_model_stages = { - "Production", - "Staging", - "Archived", - None, - } - workunits = source._get_tags_workunits() - names = [wu.get_metadata()["metadata"].aspect.name for wu in workunits] +import sys - assert len(names) == len(mlflow_registered_model_stages) - assert set(names) == { - "mlflow_" + str(stage).lower() for stage in mlflow_registered_model_stages - } +if sys.version_info >= (3, 8): + import datetime + from pathlib import Path + from typing import Any, TypeVar, Union + import pytest + from mlflow import MlflowClient + from mlflow.entities.model_registry import RegisteredModel + from mlflow.entities.model_registry.model_version import ModelVersion + from mlflow.store.entities import PagedList -def test_config_model_name_separator(source, model_version): - name_version_sep = "+" - source.config.model_name_separator = name_version_sep - expected_model_name = ( - f"{model_version.name}{name_version_sep}{model_version.version}" - ) - expected_urn = f"urn:li:mlModel:(urn:li:dataPlatform:mlflow,{expected_model_name},{source.config.env})" + from datahub.ingestion.api.common import PipelineContext + from datahub.ingestion.source.mlflow import MLflowConfig, MLflowSource - urn = source._make_ml_model_urn(model_version) + T = TypeVar("T") - assert urn == expected_urn + @pytest.fixture + def tracking_uri(tmp_path: Path) -> str: + return str(tmp_path / "mlruns") + @pytest.fixture + def source(tracking_uri: str) -> MLflowSource: + return MLflowSource( + ctx=PipelineContext(run_id="mlflow-source-test"), + config=MLflowConfig(tracking_uri=tracking_uri), + ) -def test_model_without_run(source, registered_model, model_version): - run = source._get_mlflow_run(model_version) - wu = source._get_ml_model_properties_workunit( - registered_model=registered_model, - model_version=model_version, - run=run, - ) - aspect = wu.get_metadata()["metadata"].aspect + @pytest.fixture + def registered_model(source: MLflowSource) -> RegisteredModel: + model_name = "abc" + return RegisteredModel(name=model_name) + + @pytest.fixture + def model_version( + source: MLflowSource, + registered_model: RegisteredModel, + ) -> ModelVersion: + version = "1" + return ModelVersion( + name=registered_model.name, + version=version, + creation_timestamp=datetime.datetime.now(), + ) - assert aspect.hyperParams is None - assert aspect.trainingMetrics is None + def dummy_search_func(page_token: Union[None, str], **kwargs: Any) -> PagedList[T]: + dummy_pages = dict( + page_1=PagedList(items=["a", "b"], token="page_2"), + page_2=PagedList(items=["c", "d"], token="page_3"), + page_3=PagedList(items=["e"], token=None), + ) + if page_token is None: + page_to_return = dummy_pages["page_1"] + else: + page_to_return = dummy_pages[page_token] + if kwargs.get("case", "") == "upper": + page_to_return = PagedList( + items=[e.upper() for e in page_to_return.to_list()], + token=page_to_return.token, + ) + return page_to_return + + def test_stages(source): + mlflow_registered_model_stages = { + "Production", + "Staging", + "Archived", + None, + } + workunits = source._get_tags_workunits() + names = [wu.get_metadata()["metadata"].aspect.name for wu in workunits] + + assert len(names) == len(mlflow_registered_model_stages) + assert set(names) == { + "mlflow_" + str(stage).lower() for stage in mlflow_registered_model_stages + } + + def test_config_model_name_separator(source, model_version): + name_version_sep = "+" + source.config.model_name_separator = name_version_sep + expected_model_name = ( + f"{model_version.name}{name_version_sep}{model_version.version}" + ) + expected_urn = f"urn:li:mlModel:(urn:li:dataPlatform:mlflow,{expected_model_name},{source.config.env})" + urn = source._make_ml_model_urn(model_version) -def test_traverse_mlflow_search_func(source): - expected_items = ["a", "b", "c", "d", "e"] + assert urn == expected_urn - items = list(source._traverse_mlflow_search_func(dummy_search_func)) + def test_model_without_run(source, registered_model, model_version): + run = source._get_mlflow_run(model_version) + wu = source._get_ml_model_properties_workunit( + registered_model=registered_model, + model_version=model_version, + run=run, + ) + aspect = wu.get_metadata()["metadata"].aspect - assert items == expected_items + assert aspect.hyperParams is None + assert aspect.trainingMetrics is None + def test_traverse_mlflow_search_func(source): + expected_items = ["a", "b", "c", "d", "e"] -def test_traverse_mlflow_search_func_with_kwargs(source): - expected_items = ["A", "B", "C", "D", "E"] + items = list(source._traverse_mlflow_search_func(dummy_search_func)) - items = list(source._traverse_mlflow_search_func(dummy_search_func, case="upper")) + assert items == expected_items - assert items == expected_items + def test_traverse_mlflow_search_func_with_kwargs(source): + expected_items = ["A", "B", "C", "D", "E"] + items = list( + source._traverse_mlflow_search_func(dummy_search_func, case="upper") + ) -def test_make_external_link_local(source, model_version): - expected_url = None + assert items == expected_items - url = source._make_external_url(model_version) + def test_make_external_link_local(source, model_version): + expected_url = None - assert url == expected_url + url = source._make_external_url(model_version) + assert url == expected_url -def test_make_external_link_remote(source, model_version): - tracking_uri_remote = "https://dummy-mlflow-tracking-server.org" - source.client = MlflowClient(tracking_uri=tracking_uri_remote) - expected_url = f"{tracking_uri_remote}/#/models/{model_version.name}/versions/{model_version.version}" + def test_make_external_link_remote(source, model_version): + tracking_uri_remote = "https://dummy-mlflow-tracking-server.org" + source.client = MlflowClient(tracking_uri=tracking_uri_remote) + expected_url = f"{tracking_uri_remote}/#/models/{model_version.name}/versions/{model_version.version}" - url = source._make_external_url(model_version) + url = source._make_external_url(model_version) - assert url == expected_url + assert url == expected_url