From a9271d5e96a3871027d47c4c271ba5e174fcf253 Mon Sep 17 00:00:00 2001 From: armaganngul Date: Sun, 1 Dec 2024 15:28:41 -0500 Subject: [PATCH] Changed testing --- .../app/__pycache__/__init__.cpython-312.pyc | Bin 198 -> 173 bytes .../__pycache__/__init__.cpython-312.pyc | Bin 210 -> 185 bytes .../__pycache__/__init__.cpython-312.pyc | Bin 243 -> 218 bytes .../entities/__pycache__/user.cpython-312.pyc | Bin 850 -> 825 bytes .../__pycache__/__init__.cpython-312.pyc | Bin 213 -> 188 bytes .../db_connection_manager.cpython-312.pyc | Bin 1592 -> 1567 bytes .../__pycache__/__init__.cpython-312.pyc | Bin 360 -> 335 bytes .../__pycache__/csv_file_repo.cpython-312.pyc | Bin 12162 -> 12137 bytes .../sqlite_db_repo.cpython-312.pyc | Bin 12014 -> 11989 bytes .../__pycache__/__init__.cpython-312.pyc | Bin 726 -> 701 bytes .../__pycache__/generate.cpython-312.pyc | Bin 2998 -> 3063 bytes .../__pycache__/get_headers.cpython-312.pyc | Bin 2119 -> 2094 bytes .../get_last_login_data.cpython-312.pyc | Bin 1854 -> 1829 bytes .../get_values_under_header.cpython-312.pyc | Bin 3753 -> 3728 bytes .../__pycache__/upload_data.cpython-312.pyc | Bin 1657 -> 1632 bytes backend/app/use_cases/generate.py | 5 +- .../datapoint_entity.cpython-312.pyc | Bin 2248 -> 2223 bytes backend/ml_model/model_with_score.pkl | Bin 2314 -> 2314 bytes .../data_preprocessing.cpython-312.pyc | Bin 3510 -> 3485 bytes .../__pycache__/file_reader.cpython-312.pyc | Bin 2869 -> 2844 bytes .../__pycache__/model_saver.cpython-312.pyc | Bin 1191 -> 3567 bytes backend/ml_model/repository/model_saver.py | 72 ++++- .../ml_model/repository/multiple_models.py | 197 +++++++----- .../ml_model/repository/safe_grid_search.py | 59 ++++ backend/ml_model/repository/safe_split.py | 28 +- .../__pycache__/model.cpython-312.pyc | Bin 3389 -> 10257 bytes backend/ml_model/use_cases/model.py | 288 ++++++++++++++---- .../ml_model/use_cases/multiple_model_use.py | 5 +- .../tests/repositories/test_model_saver.py | 5 +- .../repositories/test_safe_train_grid.py | 15 +- backend/tests/use_cases/test_model.py | 21 +- .../use_cases/test_multiple_model_use.py | 4 +- 32 files changed, 515 insertions(+), 184 deletions(-) diff --git a/backend/app/__pycache__/__init__.cpython-312.pyc b/backend/app/__pycache__/__init__.cpython-312.pyc index 35ccb444efa3e08bd0ce7d2b4a53e1712dd64478..f60fb9536d81df9378d50a22e4b1b8b0dcab3d15 100644 GIT binary patch delta 47 zcmX@cxR#OoG%qg~0}y2T8%^XkmhjOJElw>e)=wc#M(zG%qg~0}z}_HJr$8Y?!JaTAW%`te=&bnU!Chk*V*JTAW>yU!d>hrQljo aQk1CRTv}9=nOvHaSfrm^9A7XoEF1tUSr{|` diff --git a/backend/app/controllers/__pycache__/__init__.cpython-312.pyc b/backend/app/controllers/__pycache__/__init__.cpython-312.pyc index e5f62b9d7f1767733d6d9eeabcf47c91d6643c19..50a67483fece5134d383b277f55fc0df0c460198 100644 GIT binary patch delta 47 zcmcb_xRa6lG%qg~0}y2T8%^XkmWa>~Elw>e)=wyU!d>hrQljo aQk1CRTv}9=nOvHaSfrm^9A7XoEExbW92ijm diff --git a/backend/app/entities/__pycache__/__init__.cpython-312.pyc b/backend/app/entities/__pycache__/__init__.cpython-312.pyc index 6db9f48abd968586e854c56f843ea3293fb8d9ae..5190fee9d61fbaed74fd036f1db2e671a5d293b0 100644 GIT binary patch delta 47 zcmey&c#DzyG%qg~0}y2T8%^X6kqFffElw>e)=wyU!d>hrQljo aQk1CRTv}9=nOvHaSfrm^9A7Z8JQ)Bs9vFWB diff --git a/backend/app/entities/__pycache__/user.cpython-312.pyc b/backend/app/entities/__pycache__/user.cpython-312.pyc index f98caf19b238fa48a7989ce472c1a4fd4e0ad681..267ab79a0bb5d769eeb906a2ad1bb28210ee1799 100644 GIT binary patch delta 50 zcmcb_wv&zfG%qg~0}!+=cH7APnNcE8KeRZts8~O-C^s=ZF;73BGC2dt4%n>Eq|68a Dh5-+} delta 75 zcmdnVc8QJqG%qg~0}$xV_Sne%nb9yyKeRZts8~NMF*7T_I3rWvCAB!aB)>r4%S*wv dqNFHM!MU`kC^NYyBLIl}7@hzC diff --git a/backend/app/infrastructure/__pycache__/__init__.cpython-312.pyc b/backend/app/infrastructure/__pycache__/__init__.cpython-312.pyc index dfc1c4d8dbeb3350877e71c7520b971d2db06516..7fe7f1d0e0085ca863e54fa5c8b435cd4ee075ed 100644 GIT binary patch delta 47 zcmcc0xQCJZG%qg~0}y2T8%^XkmWb95Elw>e)=we*3U}J%*rp$$kca9EzT~9di0Ewe<1erR!OQL%nvQEp;-VxE3LWpW0P9kBT&vp*96 Df`|}* delta 75 zcmbQwvxA5GG%qg~0}$vg^4iF~mf5gTKeRZts8~NMF*7T_I3rWvCAB!aB)>r4%S*wv dqNFHM!MU`kC^NYIX AvH$=8 delta 72 zcmX@l^n!`|G%qg~0}$vg@|wuK*RV)Gv^ce>SU)Q%1kcJNi5P&E{-qQY{M}}p7GY^o9YwT0Z1$ySpWb4 diff --git a/backend/app/repositories/__pycache__/sqlite_db_repo.cpython-312.pyc b/backend/app/repositories/__pycache__/sqlite_db_repo.cpython-312.pyc index 145f3bef433321741bcd3a7202acb82718fe9a9f..21f21b23c57c8545c0d36e2caed3a31b853808af 100644 GIT binary patch delta 50 zcmaDCdo`B(G%qg~0}!lQ>9di0F1tj$erR!OQL%nvQEp;-VxE3LWpW0P9kBT(yQn$< DvWO8^ delta 75 zcmcZ_`!1IIG%qg~0}$vg^4iEfm))>JKeRZts8~NMF*7T_I3rWvCAB!aB)>r4%S*wv dqNFHM!MU`kC^NY9dhrj!7a+KeRZts8~O-C^s=ZF;73BGC2dt4wxLpR09BI CP7j6v delta 74 zcmdnXdX1I)G%qg~0}$vg^4iEP$7GnVA6lGRRIHzsn3VtQhpen4e%29O;vc_qsZ#v7B9SXVP;@lLj8lQODhtzm6oN@u8L1!A@uRu_iY zLI#Ff_8PVnLSGwq@68=2`D5AVscag-38=`L**r*v@}R^@<%pV zUe?u&5YreLYFH;rvllb+P3~uR5*8`}IXylxFC{*;EHS4vu_Sf!1@;8%a%o9w(d0x9304lEgu>)n4jsmulUH##2}!EXmh0fVAt--c z(CDI|(PcrC$saiGGHOje$LU}s#VU3~Mqz>QOwSL@Ahv~D&DV?E~ z1&CQ|SX>xl3mF({*=ks`ctO^og30=9Vw3sVI9O`fvIHmpWtCz}VFEIRK}?P+p!JYE0(maxju)6}us$ut0dG z=LcpG+rVT)@D7z7#XB`O#9cISfe8BegXssHykA%)CpU66G0IJ5;m%+*n4HTk!DutN bjXOrv=qn2oBhz;-T}H7_3_zlY4`?p{xS?#c diff --git a/backend/app/use_cases/__pycache__/get_headers.cpython-312.pyc b/backend/app/use_cases/__pycache__/get_headers.cpython-312.pyc index 037e861beb960633881b56140509eaf0bb7e1044..7d20b539e97970370f908b6ccc616e54846fa607 100644 GIT binary patch delta 50 zcmX>uuug#cG%qg~0}!lQ>9djh6thI6erR!OQL%nvQEp;-VxE3LWpW0P9k7{=#gPR7 Dh!79O delta 75 zcmZ1{a9n`r4%S*wv dqNFHM!MU`kC^NY9djBjYT3+KeRZts8~O-C^s=ZF;73BGC2dt4%pnpvWE!( DcS#T_ delta 75 zcmZ3=w~vqeG%qg~0}$vg^4iGl#$s5dA6lGRRIHzsn39di04YNe5erR!OQL%nvQEp;-VxE3LWpW0P9kBTY^Fb~E DgPRce delta 75 zcmbOryHb|>G%qg~0}$vg^4iF~hS{)AKeRZts8~NMF*7T_I3rWvCAB!aB)>r4%S*wv dqNFHM!MU`kC^NY9di03$sL|erR!OQL%nvQEp;-VxE3LWpW0P9kBTWb0-r3 Dnr4%S*wv dqNFHM!MU`kC^NYhxL%O^G%qg~0}!lQ>9dhrmq{W?KeRZts8~O-C^s=ZF;73BGC2dt4%nQ@G?9hz K*XG}BXBh$Zd=p{- delta 97 zcmZ24ctVi-G%qg~0}$vg^4iF)%Vb!sA6lGRRIHzsn3Jr-UnM2%YN@-FLYjJXZQR)Jr-UnM2%gN@-FLYjJXZQR)9djBpG6`~KeRZts8~O-C^s=ZF;73BGC2dt4%pnyqRj~a DdD9QT delta 75 zcmbO$y-k|?G%qg~0}$vg^4iGl&th1wA6lGRRIHzsn39dhrm{lTPKeRZts8~O-C^s=ZF;73BGC2dt4%qC^`ivC- DZ|V>3 delta 75 zcmbOuwpEP#G%qg~0}$vg^4iEP%xYMnA6lGRRIHzsn3+0+Yr4C7P)vg&a1`gp=R*W@mP{tr89kj9#2)3Z1NABjAWId>jbyW(-z}K>BzfAND}s`#$@FT4C=%p z$YESH72u#+nwhYYQ(dhG!m+8r`ZH_SizzdCMH^3>X|n^qoz+yfGsv9uOv^dVE?S)Z z0s>}eOT%SB{gOp}&TQ8*=?qd+=9<*@1+^|Rr^0;ObEz;qZc!#Eqh;Heb}YI-U!Hcd zv}ik4KGSyXDC2C#@+~d~+tzc<*0w%r312syA|0fywrZn6sn4tt)SH~L41a*mmTaR$ zZII5Twkv$*8W#16ZP!u5bA8a#cF#+dkt>CNdxrTQr>qD9(^9o;@p2?T^(#oeO1(?E zk4Fb-*~}khK0A)ut}iUonLr{epB63VS2)TXWA&|8FDqhtVUZ*W?TVh>C z%JQ)D|zVZq^8t}4x_NuRG2C8ECI7dY!d>p zt+GF1Gd9dyuI<~*vF94GbTW=OYB)>?it2@tH9)27M6NI(BTX)aPuKD=>jhCWesVGQ zNZbDko~Iyi9v?JP$QCDk@_>0vupomc4wNWcj${)Y^EL0lE)88)3p8Ol#Z}7!y$nx1 z(hKhh$bkjMXV`h>7DlT^3C5$`!=xn&N0=z(4wuWhA)ATZP(@(IEKD+E+H%dpjHAoo zC~&Ln2^;;t5Jro@WS%cqgS4(gI`egXuD9uP`Dk$m$RV~v_a0e(o^0PeuYK74)0Um^ z*&sh3ZCHU#GGfGz4l`6yh6RUJ2n#ZWG^^SaQmkq~*&1GrU097ku7y4V%FBw6hLmke zf^sx#taZ1~apuRsccUTqW%V()Y+idy4L~-QBH{N&6g2?eX#Ev(S@Zj%ovAI+sQF#r zIyJ>ie5%&sazKI70RN(U6;jpu$;(72v#LVIhys6A?NVZl%tG0~uC{{o{{YZyl6{m^ z;*7E;$z$L=9SOcgM#!4Y>LS!3-p$EUVL~Dbpa~!%bBv0^e1M#ahzfvcA4-*yXGY*M z5hc_K+Fo7wRVaLz<5Vd`B;YF`5Y!t6fDtpQ63=C}w~Ig}p-@Jv2vG?WLMRoStT;Z( z(5Ry%95!kbEU!#kA}IsBlfdE6dUV>VwrCq(wnyH-XHY`o8mXIX8`RPm4eF@yi0irr zu*w@Ltom4KMY*_n)!f#6zLA95Av;iKC$nmhlotd^$!jFr6{bdvzh>GDR8It2nfaxl z1CSl-08C_P1ziucDbIET%`01OkO~bHXl7-moK14f5F9TH(pUlLa5~6Hj=H=~DT*ZQ zqP~=GK)MX3$}o-2^_`4iIp17H^cNpM2T5`d+1yv}Ug-XDQ~%vf{fkGh4__O;`Rwf- zw{|RTdVN0mTW??ez`}w0p@l=Ay|&alFrU8Px3zwD;cWff!nwu$-wzyH>N`H4d5|K{ z^nCK-RdcDkf0-zq$CSkj_q%#N9)5p#aqsoqwcM94+}-x(SG#}cI{Gky+si4^x$$8~ zJbWau=D$Ps4&@Sea>=1rJHWCjE^D-rYzB>mMgwDIbkwaP4cDdX7oY@e>~!e5=^47t zF;fR>*$g`4N;OD@g(OEdd>3?|6ONfy+Ki(KWrsi~&^}klUusV>y}WO8a@XR{Wdfg@ z{C4eDO%A_4_1)MvV=K7&VBd4ejp0>zOSXge;APvOgU5s)6-@AcoNN{F1A=ikAHE)- zC;%e-%8=rQ670j1u~d)J2SZKY;91zf%r8*#<%FUr_ekeG(tVHg{!Ctam^!Q|8~-NQ H%j5q5pCh#CJDuGrS+WUm00W+KM~Sh~41EW__r3RyAK$&l z`@38&0wMFiF8yRH0KAn>F6I~?SK3@{8qb7Wx?X1>x?g;heWr8jj5uu62S5jB|1 zNlq=+MZ9wA=Lj1ZuA8{RdYsR4e?0-Xco%aQxFLIUrAxAJQt^yvZW?t1VEsTVrO+wIO~mYl6MF^ zAG2?1|HD%HoF&yT{8WXWl=B^_jz3j*4Qv_7RGtaW<>=I=AXTyYL`jXDOH4zP>(uXS zm%vrPU{`@)6+p52T2EaCU0Bl}ep9rBte{Q8ti&T$E5anS5VMdKZ-mymAChdx!)m>5 zXGy2)C!WLHh?2V4!j%j?-HniN4R=~`tw7}{s4kTp>4?m9JCr(thZjeMWW;ct$r;I02s!E5KQKa> z9Uwo5nCIKK3EN1b*j`$)?(`ChtVO}%yPXhG+htBnqky;xvlkKb=5NLEe8Wdxh&DbzL?7D4&K?()v^ zVBzV){^XVIJMX5Zce^{?-A6l*p7i#o=C>E^y;Bm*)8pq2Hx)%-h6re<@7)1jbl}0h6d0}$Aui7FnJeW&3@g` z)jxH;Q7MR-$PoQPDA8}!#)xqx#=Ocz=Min!M_pwk6b?8T{k^U#qWHGx&Pt!3UY87sdM@0R;INe*gdg diff --git a/backend/ml_model/repository/model_saver.py b/backend/ml_model/repository/model_saver.py index 1b02e75e..5f324dff 100644 --- a/backend/ml_model/repository/model_saver.py +++ b/backend/ml_model/repository/model_saver.py @@ -1,22 +1,72 @@ import os import pickle - import pandas as pd from sklearn.model_selection import GridSearchCV -def save_model(best_clf: GridSearchCV, x_test: pd.DataFrame, y_test: pd.Series) -> None: +class ModelSaver: """ - Saves the model as a pkl file + Saves the trained model and its evaluation score as a pickle (.pkl) file. + + Parameters: + ----------- + best_clf : GridSearchCV + The trained model object, which is an instance of GridSearchCV containing the best estimator after hyperparameter tuning. + + x_test : pd.DataFrame + The test dataset features used for evaluating the model. + + y_test : pd.Series + The actual labels corresponding to the test dataset. + + Returns: + -------- + None + The function saves the model and its score to a file named `model_with_score.pkl` in the parent directory of the current script location. """ - # Overall model score - score = best_clf.score(x_test, y_test) + def __init__(self, best_clf: GridSearchCV, x_test: pd.DataFrame, y_test: pd.Series): + """ + Initializes the ModelSaver class with model, test features, and test labels. + + Parameters: + ----------- + best_clf : GridSearchCV + The trained model object, which is an instance of GridSearchCV containing the best estimator after hyperparameter tuning. + + x_test : pd.DataFrame + The test dataset features used for evaluating the model. + + y_test : pd.Series + The actual labels corresponding to the test dataset. + """ + self.best_clf = best_clf + self.x_test = x_test + self.y_test = y_test + + def save_model(self) -> None: + """ + Saves the trained model and its evaluation score as a pickle (.pkl) file. + + Returns: + -------- + None + The function saves the model and its score to a file named `model_with_score.pkl` in the parent directory of the current script location. + + Notes: + ------ + - The `score` is calculated using the `score` method of the `best_clf` object, which typically represents accuracy for classification models. + - The resulting pickle file contains a dictionary with two keys: + - "model": the `best_clf` object. + - "score": the evaluation score of the model on the test data. + """ + # Overall model score + score = self.best_clf.score(self.x_test, self.y_test) - curr_dir = os.path.dirname(__file__) - model_path = os.path.join(curr_dir, "../model_with_score.pkl") + curr_dir = os.path.dirname(__file__) + model_path = os.path.join(curr_dir, "../model_with_score.pkl") - # Save the model and its score - with open(model_path, "wb") as f: - pickle.dump({"model": best_clf, "score": score}, f) + # Save the model and its score + with open(model_path, "wb") as f: + pickle.dump({"model": self.best_clf, "score": score}, f) - return + return diff --git a/backend/ml_model/repository/multiple_models.py b/backend/ml_model/repository/multiple_models.py index cb4fa8e7..722815b9 100644 --- a/backend/ml_model/repository/multiple_models.py +++ b/backend/ml_model/repository/multiple_models.py @@ -12,95 +12,124 @@ from backend.ml_model.repository.data_preprocessing_multiple_models import DataProcessorMultiple from backend.ml_model.repository.fairness import FairnessEvaluator from backend.ml_model.repository.file_reader_multiple_models import FileReaderMultiple -from backend.ml_model.repository.safe_train_grid import safe_train_test_split +from backend.ml_model.repository.safe_split import SafeSplitter current_dir = os.path.dirname(os.path.abspath(__file__)) csv_file_path = os.path.join(current_dir, "../../../database/output.csv") -def evaluate_multiple_models(model_files): - file_reader = FileReaderMultiple(csv_file_path) - df_dropped, inputs, target = file_reader.read_file() - - data_processor = DataProcessorMultiple(inputs) - inputs_encoded = data_processor.encode_categorical_columns() - inputs_n = data_processor.drop_categorical_columns() - - split_data = safe_train_test_split(inputs_n, target) - if split_data is None: - return {} - - x_train, x_test, y_train, y_test = split_data - - # Debug: Print columns of x_test - print(f"Columns in x_test: {x_test.columns}") - - # Update sensitive features list to match the actual column names in x_test - sensitive_features_list = ["gender_N", "age_groups_N", "race_N", "state_N"] - - # Dictionary to store results - results = {} - - for model_file in model_files: - try: - # Load the model - with open(model_file, "rb") as f: - model_dict = pickle.load(f) - model = model_dict["model"] - print(f"Loaded object type: {type(model)}") - print(f"Loaded object: {model}") - - # Make predictions - y_pred = model.predict(x_test) - - # Debug: Check predictions - print(f"True Labels (y_test): {y_test.head()}") - print(f"Predicted Labels (y_pred): {y_pred[:10]}") - - # Dictionary and list to store fairness results for the model - model_results = {} - fairness_values = [] - - for sensitive_feature in sensitive_features_list: - # Select the sensitive feature column for the evaluation - if sensitive_feature in x_test.columns: - sensitive_col = x_test[sensitive_feature] - print( - f"Sensitive Feature: {sensitive_feature}, " - f"Column Data: {sensitive_col.head()}" - ) - - # Evaluate fairness for this specific sensitive feature - fairness_evaluator = FairnessEvaluator( - y_test, y_pred, sensitive_col - ) - metric_frame = fairness_evaluator.evaluate_fairness() - - average = sum(metric_frame.by_group["accuracy"]) / len( - metric_frame.by_group["accuracy"] +class MultiModelEvaluator: + """ + A class for evaluating multiple machine learning models on a dataset. + + This class handles data preprocessing, model loading, and fairness evaluation + based on specified sensitive features, producing metrics for each model. + """ + def __init__(self, model_files: list): + """ + Initializes the ModelEvaluator with the file path to the CSV and model files. + + Parameters: + ----------- + + model_files : list + List of file paths to the model pickle files to evaluate. + """ + self.model_files = model_files + + def evaluate_models(self) -> dict: + """ + Evaluates multiple models for fairness and performance metrics. + + Returns: + -------- + dict + A dictionary containing evaluation results for each model file. + """ + # Read and preprocess the data + file_reader = FileReaderMultiple(csv_file_path) + df_dropped, inputs, target = file_reader.read_file() + + data_processor = DataProcessorMultiple(inputs) + inputs_encoded = data_processor.encode_categorical_columns() + inputs_n = data_processor.drop_categorical_columns() + + data_splitter = SafeSplitter() + split_data = data_splitter.train_test_split(inputs_n, target) + if split_data is None: + return {} + + x_train, x_test, y_train, y_test = split_data + + # Debug: Print columns of x_test + print(f"Columns in x_test: {x_test.columns}") + + # Update sensitive features list to match the actual column names in x_test + sensitive_features_list = ["gender_N", "age_groups_N", "race_N", "state_N"] + + # Dictionary to store results + results = {} + + for model_file in self.model_files: + try: + # Load the model + with open(model_file, "rb") as f: + model_dict = pickle.load(f) + model = model_dict["model"] + print(f"Loaded object type: {type(model)}") + print(f"Loaded object: {model}") + + # Make predictions + y_pred = model.predict(x_test) + + # Debug: Check predictions + print(f"True Labels (y_test): {y_test.head()}") + print(f"Predicted Labels (y_pred): {y_pred[:10]}") + + # Dictionary and list to store fairness results for the model + model_results = {} + fairness_values = [] + + for sensitive_feature in sensitive_features_list: + # Select the sensitive feature column for the evaluation + if sensitive_feature in x_test.columns: + sensitive_col = x_test[sensitive_feature] + print( + f"Sensitive Feature: {sensitive_feature}, " + f"Column Data: {sensitive_col.head()}" + ) + + # Evaluate fairness for this specific sensitive feature + fairness_evaluator = FairnessEvaluator( + y_test, y_pred, sensitive_col + ) + metric_frame = fairness_evaluator.evaluate_fairness() + + average = sum(metric_frame.by_group["accuracy"]) / len( + metric_frame.by_group["accuracy"] + ) + rounded = round(average, 3) + + fairness_values.append(average) + + # Store the metrics for this demographic + demo_name = sensitive_feature.replace("_N", "") + model_results[demo_name] = rounded + else: + print( + f"Sensitive feature '{sensitive_feature}' " + f"not found in x_test." + ) + + if fairness_values: + model_results["variance"] = round(pd.Series(fairness_values).var(), 7) + model_results["mean"] = round( + sum(fairness_values) / len(fairness_values), 3 ) - rounded = round(average, 3) - - fairness_values.append(average) - - # Store the metrics for this demographic - demo_name = sensitive_feature.replace("_N", "") - model_results[demo_name] = rounded - else: - print( - f"Sensitive feature '{sensitive_feature}' " - f"not found in x_test." - ) - - if fairness_values: - model_results["variance"] = round(pd.Series(fairness_values).var(), 7) - model_results["mean"] = round( - sum(fairness_values) / len(fairness_values), 3 - ) - # Store results for the model - results[model_file] = model_results + # Store results for the model + results[model_file] = model_results - except Exception as e: - results[model_file] = f"Error: {e}" + except Exception as e: + results[model_file] = f"Error: {e}" - return results + return results diff --git a/backend/ml_model/repository/safe_grid_search.py b/backend/ml_model/repository/safe_grid_search.py index e69de29b..a3156f8d 100644 --- a/backend/ml_model/repository/safe_grid_search.py +++ b/backend/ml_model/repository/safe_grid_search.py @@ -0,0 +1,59 @@ +from sklearn.model_selection import GridSearchCV +from sklearn.tree import DecisionTreeClassifier +import pandas as pd + + +class SafeGridSearch: + """ + A utility class for safely performing grid search for hyperparameter tuning. + + This class ensures that grid search is executed while handling cases + where there are insufficient samples for cross-validation. + """ + + def __init__(self, classifier=DecisionTreeClassifier(), param_grid=None): + """ + Initializes the SafeGridSearch with a classifier and a parameter grid. + + Parameters: + ----------- + classifier : estimator object, optional (default=DecisionTreeClassifier()) + The base classifier to use for grid search. + + param_grid : dict, optional + The parameter grid to use for tuning hyperparameters. If None, + a default parameter grid for DecisionTreeClassifier is used. + """ + self.classifier = classifier + self.param_grid = param_grid or { + "criterion": ["gini", "entropy"], + "max_depth": [None] + list(range(1, 11)), + "min_samples_split": [2, 5, 10], + } + + def perform_search(self, x_train: pd.DataFrame, y_train: pd.Series): + """ + Performs a safe grid search for hyperparameter tuning. + + Parameters: + ----------- + x_train : pd.DataFrame + The training feature set. + + y_train : pd.Series + The training labels. + + Returns: + -------- + estimator or None + Returns the best estimator if grid search is successful. + Returns None if there are insufficient samples for cross-validation. + """ + try: + grid_search = GridSearchCV(self.classifier, self.param_grid, cv=5, scoring="accuracy") + grid_search.fit(x_train, y_train) + return grid_search.best_estimator_ + except ValueError as e: + if "Cannot have number of splits n_splits" in str(e): + print("Not enough samples for cross-validation. Returning None.") + return None diff --git a/backend/ml_model/repository/safe_split.py b/backend/ml_model/repository/safe_split.py index 4849877a..b22bbd76 100644 --- a/backend/ml_model/repository/safe_split.py +++ b/backend/ml_model/repository/safe_split.py @@ -4,14 +4,28 @@ class SafeSplitter: """ - A class for safely splitting datasets into training and testing subsets. + A utility class for safely splitting datasets into training and testing subsets. This class ensures that a dataset is properly split while handling cases where the sample size is too small to perform the split. """ - @staticmethod - def train_test_split(inputs: pd.DataFrame, target: pd.Series, test_size=0.2, random_state=48): + def __init__(self, test_size=0.2, random_state=48): + """ + Initializes the SafeSplitter with parameters for splitting. + + Parameters: + ----------- + test_size : float, optional (default=0.2) + Proportion of the dataset to include in the test split. + + random_state : int, optional (default=48) + Controls the shuffling applied to the data before splitting. + """ + self.test_size = test_size + self.random_state = random_state + + def train_test_split(self, inputs: pd.DataFrame, target: pd.Series): """ Splits the dataset into training and testing subsets safely. @@ -23,12 +37,6 @@ def train_test_split(inputs: pd.DataFrame, target: pd.Series, test_size=0.2, ran target : pd.Series Target labels of the dataset. - test_size : float, optional (default=0.2) - Proportion of the dataset to include in the test split. - - random_state : int, optional (default=48) - Controls the shuffling applied to the data before splitting. - Returns: -------- tuple or None @@ -37,7 +45,7 @@ def train_test_split(inputs: pd.DataFrame, target: pd.Series, test_size=0.2, ran """ try: x_train, x_test, y_train, y_test = train_test_split( - inputs, target, test_size=test_size, random_state=random_state + inputs, target, test_size=self.test_size, random_state=self.random_state ) return x_train, x_test, y_train, y_test except ValueError as e: diff --git a/backend/ml_model/use_cases/__pycache__/model.cpython-312.pyc b/backend/ml_model/use_cases/__pycache__/model.cpython-312.pyc index 56f8660faa9b43c5786ad85f4e651bb0c808e13c..c7279b418d30e50c97852e3c841f14e8251dde3a 100644 GIT binary patch literal 10257 zcmcIqU2GItcCP*}_pj{+{0F9h!MMS04;UDh46tT~pGg>o#UxpE&{WD@)pnceuFkD$ zvqs%oHk#EMDH6;R5Jeg9(|UxIB~RfY53_HxTCLPQJ*&2A7YQlZhmrC?AX222BITTW ztE#KpZS2WU$+~^3?z!ild+&G7J?C8iDi#ZK@ErN)EAyiVIPPEY#ePCo<@F9P$35g^ zPWGr=hW&an9(ebv-hwaV^Kc5c$lKjozl)Jp-L7py8ME74znf4dS`>Co;h%U|UG~Yo zCq7oUzh!z9Pgm)7+WlLRrS*UsE5tK#S`Vs;LT{#*)%5vQ?8xk(^&WL+VOM4st%uct!eC~Q)+6drVK_5P>rr*2usgGR)Aj2%*X^&gvnuwS zGJ7^z({8K3M~*##ea9y=>g6tRa{LocPB`E8E_+hF-(d6Al*jZNL{aW|(`l&h{;;eX z`HPBVl!@}bWJuSve9?H1zz2;$`<~f+ab8t6Cy3EbM9V6=PJ8rSl=7qqZx@y%wQR8o zU(Bn@bw!dDWE%cZlNEJVTB2{!St+N?mejmq(DyjLUm|&VR*^_{0qF}!FP4qGdS`b! zJ;nZHkakYel_{-kl*&drt1o47iuPkci+g<)zB~ljqZ(wqvM1wH{IXZ}<$SXL6Hg`} z2R`94K{*IzNDe{C%RH1lau1YYISge)jzAfeqfo|{W2v|qrE|DJqA`6spI!rm=Sjg*^ zu+|_<#+;_A+HLHq8%jw(O^J`Dg)5pQ>y&YcDAq)*Y2i+!Dc!NOKzEX{!tVmWNpsB> z1cTKuD1XJ!S-0u5aE(n=7v^E-1)Xla08{TiZPx90TKG`91ww#Caz1P1wW7uFL|U-m z2d(z92vU}4us6C=)bmDuNfB}k_g1HqAUt4_$xfHbnXW|OGo22Kw3OG&lA6CG;fPF> z`M-UP&o9q?^ERL@>$gD&zo?jq#n)*l9&!~=g%hF1o{AUVEQ82#6(77=ip+7Kv|i~K zHnI?19~41u9N7}mjX~jd-dGTd3M!&RmIVhXFgv<8$l#g{A8TLn+hpr;=s=gm9H^{V zl7Q)HfC6#au+I$yGdAJgLjAqjKNFB%$ZAEy0ycifCDRqdB>`@}sKX3^YfjZ9V?xkM zxQ>!4jLS++0(*4k?ewwK#_@<&(g>1L{b_MWA?6gJtOLSvdn|mp1|^^mT0zte&_oy) zAlYry zA-yp-Sc>x!BqeE<;Bk3|v+9TA1B{_y>&Kh+hYPf4U8%}d3pYs=XMs9z)6qLMZA zsrMv(VREKanmnJE^vUyO&|;vTld`vzqC8blMT+bxn2eZ(z0s$bQ7A2&yeNW6GDPvt zPB*a9_UD5r#|RX6xtF8d=#drw*ZlDM(Eb(w-^AB& zVH;pjpS&T=EekSWs)V*EPl%#|DJ>cRVYOT+>dgM)lA|#NQx0bk7)(IWyhid_fa{V( z^3t5DxE1WW0tacai*`A1Hp(T?lhXk9oJfaP}ECClkXUi8V&Fs;GjP( zv<`hin?rOu#3>&^AC!XJ>_K{JZc<*d13sG(7aJWxA_xRMKywf*hAhH{h^<3NL?d(O zU>jsNBRd65^l>Qea{n>R4eVRp(cCv(-#6XZb$lhV-aFFlJy`EO*yug9Hd0HRsPQMr z01#*0ICb#+*U-|2XU`KGyWWa-5$QS}8@p8>w7&t{RX?=b9=(f)(?0+Qg})NW`KSi` zn_B}ebOEY`XbNP)07|Vk1GQ|S#APqHdrUCc0!T`7+QFr>$q2z4fEY^UzzvQA%V1l> zyT<(%+QWtxEglo#L{UWs9kN=Aom$l11B7;Dv(W9J1lv75<1F)TW#xT#1f0D_znFZ}`(pos!Ed}wv4 zF_d13uMZt+4xO(Loo@`i4{ynR&1AZsOgEBmukiHkbUk^xk$e~45_=x~U~RCLI9lV6 zGBtIp9G;i>{yGOmo0@h#cHl6=9vzPzK!L*gx+x_1a$cK;s^YKRs@UNPBk~kLP)!KS_6MlIhm>Fuw~1j5BuEe$;R%Z|7E0#CSsjrFc!1WB z`CQ9q!(c(R;sKl-=o%(nT8^06OF2?OMC$4R2_9PHEhx+&TM08p6^Xfm;C~^ukR3b4 z!t^^kw+7l2q}A{-^w1}v0R7m{4INy)@w1PA`tj2n%@Y^vCoVPyFRjGZ6T6#<$$DaP zZQ)t4kvLc5&%HqpP6IEu9yo^yV%x1bK!?u;Y~OQZ+X>gCwtx*h@mwAk8diP|o#6sT z?gEDCvH*^I*BI7?mhEERg2+~ocQ)UJeb;Cft{9iikZt4Ve*Ulcrj72v z*Cs`1TA^?UF1;iIn`FkV*j;2~78cErwLoUbnxh#*_7>F!KqhT2E}}a^8#2345D-A$ z3k5adLx&oJ;}$57)DuV6Ezs0M9@{7tRVP7{%+8; zp-W9}4>)IzA!r5th(IUsw`>9123N9xd9@JRm%uqiucTXQz+t{!a55{k1**!B?~&#d z719&1+0;_e;IS=OUPd=p22RCNS|UVRc4R?gas#GlO$Laz=H^UJwDTTY>5tRH@!o8| z2-eCkbMJBbz>jpztjTy$o_$uW4nGuoyh8;1Nb(C7!C0R0%}MyJ-6 zo-WnSeXp_SdJW>{#J*-?x}KPBxZ;eiI&u~`{Lgj7!_r78xAAUexu}5z@G#B5KsRqd zk}sPr6Dhka=vj>@u5I0}2B2*dvuY0T1+{zuv%qZ|i_mQw+cgL~^|yj1O|@)l)=`yc zRI-^>M}iF&o$fCbS0N7VfmJudkV92i!o(qq)Jyam87+ADh zT)!l4I466DX$= z1amBdZT(yzwc{XQfbpLr077V6A)%Ys|EVc-M6m<3tu)D{MBgqha3yJAj8?M8kpLG1 zH5~~BmK0)1R5(4d3rhADJ^OTsKvZRtoh%dzuC>|vAYlb4$x$10fxV0GlULX~BJ@1kEH9(F(3bDmp>sA*e5AqC*4{#jm|txX&q4C!ZlZ z#j_KLx4}p7Y%rtllaK}xWSD-JORYeh;31%6$_&>O9~MdUXJJ_lg^nV!Veag0JC!yP z`~$N35DG}+9_OM1&B)=eB8PuBa_ZUO7pK0w*%-;(yYjCi?>ytb82|E0Ei=~`$*x>^ zF+A2BK3N|=*%&^xlK5R>sF^riPaIzRqi2am;&P3@%q$d~GR}hsF||-%!dI|R+@ibH zBC{>ZO)wBuPt{xT+=M<=U)5g?Ko9KS@qm@G_)yAVHB{w+(y?VN@(N!G*jA|8Q|Yj12vz!F1-(zKm8~W#NpQIR5@zrn^+NS7z=e;HD;Mbn9Ow|-d_sC0 zJX)rttUo?&h=%>i)_!k6Pm6B;{n&G-42KUKT|=~7tDhD zZ0ce%I?V;@rjv>&3QC~_7ev%BK`e)Bp$Kr9SUiNc!&-TM0j_Bakgm`%8vzvVAP~Gz zcPip40UhDKLGY)-1n*(UK`ahqfwzuk$chHckmaV6GuVc{DnT!Vyaz>U2Qx!qYPD#t zgc+Q0>zaWGNR4@NvLA~BSiFS=9SqYF1oH_5uQaIxM81Qy<5+wbixXJPU_mVxCP;8) zIv>l5QB2UZzU@SJSh0-+HQ=Zpg2ezUHpm?~)ZBlrzW-cf|4g0VyApozC+i2N>imAF zU3_%$WsplBYfhf2Po8;}YfN6K^M_aBkNDMbXxKIUQ2bO}{r=igW9Lb5x*|Jkkv;3l zohu)_IC7%SkFSIt?0Ph?ItHXW2Ajzv^(2dIlV>0W8A;Y6BkRclXzU+u_U)_p?OPpd z^rcpUNWQz7JWx*_Se>aS$5wdw0{07v^jC@WOCJ|YKDhkg^jFc*7kvYd23Jot4@}n& zOh28eA2?YXe5c-b>Sc)Q8F?A!26s0HChG%}Yq_Vn+SzN3f$yzEfmKiJe&SvN<}kK4 z_BDSLzKrZ`4j-)#AALIh?D7|C?Z(a8$N9$a;!5I0VzkDOQu}-r{&4|r`HFCRB#LHO zMAxdJ-E=|++rD+;2el_^l0gWuDU1vT$ zVl^Tb@Ly>_mzh^^;#>A+J_dCw%=FfMo&UiA?}=Rpy!EC|c0|j<<3n#G#e09|tz5bTpT9SMU%RJ0n)_`2@%(4nV{L7r uG5Ss;e(EmVMD{-neHwZg{WSXMFKUM`H+ny~8~e}TEuUxaD-KIKmH!6F#H*_S delta 1661 zcmY*ZO>7%Q6rQymuVZ_?an?VHon+SyNt&e8qz!H=s;Z(tlmKm|(o#i5%i>+z>pJW0 z?6^v*wc<(~jNp*6hvtCPNL&iJ1aaww12+(tY!}Ja3W*+iL0lrG5>gL{nN8!ulXmva zo9}&Z-kW)I|H_Y<*slx|23XF0vsi9N?#G6F|4`o_zCgv7hnI#aAOks4%PeP?vefQwQnTq1IlAQ|jsPr;%FJhAX>666j#+1XU%Y8ZcNB5CE-STZpF3bkB}FVl zRTd3Jg2jrP^ehNUswN`EK%$}7RAl|_J2@A`bbkEpeTTUp07MafBIso-0JdJ*n!p6r zKI?WSbLb{^X;MvJ(~szVH~^-v2Mm}2>mHx5jt+%_%|0_wjbN{AUl3~y)7f~uX;G_`$y7P(9?QKR?uCz;bTx6{ z^Xf1AhVi7y;w(~DF2f~*X4quQ6tQX?(TSh+7tiwph!{OFqh`z;*djH;(u|w&YL;Zf z52De18%~|U-q8aLMe5{5;KvZ(!?0%196Z1@Ig_jA2+w#z8$;v&N15~%&Y#FD9Fss3 z`!)colO*FcU?zIiNSKMMU^;O3tFVCOtW>KR0v-fZQ3OfXVO@u+gcPCF(27W{YleX8 zLPw+#pX<0!E6!!1VQIGG#ilPZnK@CV8F=a{BWWTDFt$VrTDPkx? zNv|0D=u42?oj=L7nNGkE{<%!ixFeR-nj-2Fs<{17kz{XfMPl!+!3f4c(iJYPLq%4L z$fb+*T4Px=T+YC8u;cz*ELN1_El6&i8&~c~wFb^gEJ-TV6vKf1gs6`!va*CEQLn2S z5^Jh~AgL&flBI!SO_4N_0KD}h#&pt+c)KI}QlJ}_OQMWxrYo`=R5iVU3^#}*SXPi# zjLk2)VXuQc3YT6Hz1DJrSi#-p(yIpstPn%!Q_^(BTFDKriX=fdR8k~dO=ZrymOeb` zMsdlZS`UqoU@|DaI{S z*Yk^u!ub`v72#|HVzx$2g88B$>J}e*d#%qqW|+rBDZ<#q$I}AOqktRz`eSEsw)I{& zccham*tvp}Td)ITt;O}=M(}wMq{cT3PI9Joshb?{Bv0GP(@ydXp0T4HcE)CB9QH`- zqBkqpY{6j{S{Jd`;MC^IcFKt_v=+N`d}oaw?y~WA|I?{CJMdcTo%Z$hk2XpeHy~`% zkJ1i1hdF3=hfa4{ZZ`;GsSY!0Gou?NhnZ+yz?Olb4$Iptze(Bbq1HugiKaV|nJ1B% zT|ek&*F){QPePfe47Yw`W2M9AZ9czE+5EAc#BrOMU)!ZYU;23rWM&@~whKFNec%jV z+6kn(+(d^vVRI)O?quuzXMtFIWh1qj`Ymu62MBQ;dh!W9`FnbP`^MvypGTbZ^>(zI zne1c=cBbHD7CvX6rH^-`iS^26^igHc?;8rYqu4V$)q}LzXDOkRP0iV+;h`W+#x@!v_JT`5&?qUsBYpm75$%{O6@6FSyqf7BzKjhCG?x>3s{}8$>^fhZmt! k_*fe?