-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathanova_in_spark.py
49 lines (42 loc) · 2.38 KB
/
anova_in_spark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from pyspark.sql.functions import lit, avg, count, udf, struct, sum
from pyspark.sql.types import DoubleType
def one_way_anova(df, categorical_var, continuous_var):
    """
    Given a Spark Dataframe, compute the one-way ANOVA using the given categorical and continuous variables.

    Rows in which either the grouping column or the continuous column is null are ignored, so that
    the group sizes, both sums of squares and the degrees of freedom are all based on the same rows.

    :param df: Spark Dataframe
    :param categorical_var: Name of the column that represents the grouping variable to use
    :param continuous_var: Name of the column corresponding the continuous variable to analyse
    :return: Sum of squares within groups, Sum of squares between groups, F-statistic, degrees of freedom 1, degrees of freedom 2
    :raises ValueError: if no row has a non-null value in both columns
    :raises ZeroDivisionError: if the F-statistic is undefined, i.e. there is a single group, every group
        holds exactly one observation, or there is no variation within the groups
    """
    # Only the two analysed columns are needed; dropping the rest keeps the join below small.
    observations = df \
        .where(df[categorical_var].isNotNull() & df[continuous_var].isNotNull()) \
        .select(categorical_var, continuous_var)

    # Grand mean and number of observations in a single pass.
    global_avg, n = observations \
        .select(avg(continuous_var), count("*")) \
        .take(1)[0]
    if n == 0:
        raise ValueError("one_way_anova needs at least one row with non-null '%s' and '%s'"
                         % (categorical_var, continuous_var))

    avg_in_groups = observations.groupby(categorical_var).agg(
        avg(continuous_var).alias("Group_avg"),
        count("*").alias("N_of_records_per_group"))

    # Between-group sum of squares: sum over groups of n_g * (mean_g - grand_mean)^2.
    # Native column arithmetic is used instead of a Python UDF so the work stays in the JVM.
    # The number of levels (m) is simply the number of rows of the per-group summary.
    between_term = avg_in_groups["N_of_records_per_group"] \
        * (avg_in_groups["Group_avg"] - lit(global_avg)) ** 2
    ssbg, m = avg_in_groups \
        .select(sum(between_term.cast(DoubleType())), count("*")) \
        .take(1)[0]

    # Within-group sum of squares: sum over observations of (x - mean_of_its_group)^2.
    # Joining on the column *name* avoids the ambiguous self-join condition that arises when a
    # frame is compared against a frame derived from itself, and keeps a single grouping column.
    within_df_joined = observations.join(avg_in_groups, on=categorical_var)
    within_term = (within_df_joined[continuous_var] - within_df_joined["Group_avg"]) ** 2
    sswg = within_df_joined \
        .select(sum(within_term.cast(DoubleType()))) \
        .take(1)[0][0]

    df1 = m - 1  # degrees of freedom between groups
    df2 = n - m  # degrees of freedom within groups
    f_statistic = (ssbg / df1) / (sswg / df2)
    return sswg, ssbg, f_statistic, df1, df2