Skip to content

Commit

Permalink
fix social bias plots script pyright errs
Browse files Browse the repository at this point in the history
  • Loading branch information
hunarbatra committed Feb 4, 2024
1 parent bc3716d commit e6eef21
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions scripts/social_biases/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,23 +530,23 @@ def discrim_eval_plot(
tasks_frequency_standard_error = task_age_frequency_standard_error

discrimination_score = (
tasks_mean - baseline_mean if task != "discrim_eval_age" else tasks_mean - baseline_age_mean
) # type: ignore
tasks_mean - baseline_mean if task != "discrim_eval_age" else tasks_mean - baseline_age_mean # type: ignore
)
discrimination_score_frequency = (
tasks_sum - baseline_sum if task != "discrim_eval_age" else tasks_sum - baseline_age_sum
) # type: ignore
tasks_sum - baseline_sum if task != "discrim_eval_age" else tasks_sum - baseline_age_sum # type: ignore
)
discrimination_score_standard_error = (
np.sqrt(tasks_standard_error**2 + baseline_standard_error**2)
if task != "discrim_eval_age"
else np.sqrt(task_age_standard_error**2 + baseline_age_standard_error**2)
)

tasks_log_odds = np.log((tasks_mean) / (1 - tasks_mean))
tasks_log_odds = np.log((tasks_mean) / (1 - tasks_mean)) # type: ignore
baseline_log_odds = (
np.log((baseline_mean) / (1 - baseline_mean))
np.log((baseline_mean) / (1 - baseline_mean)) # type: ignore
if task != "discrim_eval_age"
else np.log((baseline_age_mean) / (1 - baseline_age_mean))
) # type: ignore
else np.log((baseline_age_mean) / (1 - baseline_age_mean)) # type: ignore
)
logodds_discrimination_score = tasks_log_odds - baseline_log_odds # type: ignore
tasks_log_odds_standard_error = np.sqrt(1 / (tasks_count * tasks_mean * (1 - tasks_mean))) * 1.96 # type: ignore

Expand Down Expand Up @@ -731,7 +731,7 @@ def get_intervention_name(intervention_name: str) -> str:
se_pueb_list = []

for model_name, model_data in combined_df.groupby("model"):
PUO, SE_PUO, PUEB, SE_PUEB = compute_BBQ_combined_classification(model_data)
PUO, SE_PUO, PUEB, SE_PUEB = compute_BBQ_combined_classification(model_data) # type: ignore

puo_list.append(PUO)
pueb_list.append(PUEB)
Expand Down

0 comments on commit e6eef21

Please sign in to comment.