Skip to content

Commit

Permalink
feat: enhance data visualization for numeric and categorical y-axis v…
Browse files Browse the repository at this point in the history
…alues
  • Loading branch information
Ovler-Young committed Nov 21, 2024
1 parent 8c3ee28 commit 3da7dba
Showing 1 changed file with 57 additions and 25 deletions.
82 changes: 57 additions & 25 deletions src/ia_collection_analyzer/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,30 +140,62 @@
if plot_button and x_axis != y_axis:
st.write("Plotting the data...")
st.write(f"X-axis: {x_axis}, Y-axis: {y_axis}")

# Create comprehensive aggregation table
all_metrics = (
filtered_pd.groupby(x_axis)[y_axis]
.agg(
[
("Count", "count"),
("Sum", "sum"),
("Mean", "mean"),
("Median", "median"),
("Min", "min"),
("Max", "max"),
]

# if y_axis is hashable , plot
if isinstance(filtered_pd[y_axis].iloc[0], (int, float, np.int64, np.float64)):
# Create comprehensive aggregation table
all_metrics = (
filtered_pd.groupby(x_axis)[y_axis]
.agg(
[
("Count", "count"),
("Sum", "sum"),
("Mean", "mean"),
("Median", "median"),
("Min", "min"),
("Max", "max"),
]
)
.reset_index()
)
.reset_index()
)

# Display complete aggregated data
st.write("Complete aggregation metrics:")
# Create multi-line chart (excluding Count since it's often on different scale)
metrics_for_plot = all_metrics.drop(columns=["Count", "Sum", "Max"])
metrics_for_plot = metrics_for_plot.set_index(x_axis)

st.write("Multi-metric trend lines:")
st.line_chart(metrics_for_plot)

st.write(all_metrics)
# Display complete aggregated data
st.write("Complete aggregation metrics:")
# Create multi-line chart (excluding Count since it's often on different scale)
metrics_for_plot = all_metrics.drop(columns=["Count", "Sum", "Max"])
metrics_for_plot = metrics_for_plot.set_index(x_axis)

st.write("Multi-metric trend lines:")
st.line_chart(metrics_for_plot)

st.write(all_metrics)

# if y_axis is not numeric, count and plot
else:
st.write("Analyzing distribution across categories...")

# Create mask for list and non-list values
is_list_mask = filtered_pd[y_axis].apply(lambda x: isinstance(x, list))

# Handle list values
list_data = filtered_pd[is_list_mask][[x_axis, y_axis]].copy()
exploded_list = list_data.explode(y_axis)

# Handle non-list values
non_list_data = filtered_pd[~is_list_mask][[x_axis, y_axis]]

# Combine results efficiently
expanded_df = pd.concat([exploded_list, non_list_data])

# Create pivot table and plot
pivot_table = pd.crosstab(
expanded_df[x_axis],
expanded_df[y_axis],
normalize='index'
) * 100

st.bar_chart(pivot_table)

st.write("Distribution counts:")
counts_df = pd.crosstab(expanded_df[x_axis], expanded_df[y_axis])
st.write(counts_df)

0 comments on commit 3da7dba

Please sign in to comment.