Skip to content

Commit

Permalink
Refactored and modularised update_ui function.
Browse files Browse the repository at this point in the history
  • Loading branch information
pineapple-cat committed Mar 8, 2024
1 parent 24db419 commit 11863ad
Showing 1 changed file with 140 additions and 89 deletions.
229 changes: 140 additions & 89 deletions post-processing/streamlit_post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ def update_ui(post: PostProcessing, config: ConfigHandler):
"""

# stop the session state from resetting each time this function is run
if st.session_state.get("post") is None:
st.session_state.post = post
st.session_state.config = config
state = st.session_state
if state.get("post") is None:
state.post = post
state.config = config

post = st.session_state.post
config = st.session_state.config
Expand All @@ -40,6 +41,7 @@ def update_ui(post: PostProcessing, config: ConfigHandler):
if show_df:
st.dataframe(post.df[post.mask][config.plot_columns], hide_index=True, use_container_width=True)

# config file uploader
st.divider()
st.file_uploader("Upload Config", type="yaml", key="uploaded_config", on_change=update_config)

Expand All @@ -48,75 +50,16 @@ def update_ui(post: PostProcessing, config: ConfigHandler):
if title != config.title:
config.title = title

st.write("#### Axis Options")
with st.form(key="axis options"):
axis_select("x", config.x_axis)
sort = st.checkbox("sort descending", True if config.x_axis.get("sort") == "descending" else False)
axis_select("y", config.y_axis)
submit = st.form_submit_button("Update Axes")
if submit:
update_axes()
config.x_axis["sort"] = "descending" if sort else "ascending"

st.write("#### Filter Options")
# allow wide multiselect labels
st.markdown(
"""
<style>
.stMultiSelect [data-baseweb=select] span{
max-width: inherit;
}
</style>""",
unsafe_allow_html=True)

st.write("###### Current AND Filters")
st.multiselect("AND Filters", config.and_filters if config.and_filters else [None],
config.and_filters, key="and", on_change=update_filter,
args=["and"], placeholder="None", label_visibility="collapsed")

st.write("###### Current OR Filters")
st.multiselect("OR Filters", config.or_filters if config.or_filters else [None],
config.or_filters, key="or", on_change=update_filter,
args=["or"], placeholder="None", label_visibility="collapsed")

st.write("###### Current Series")
st.multiselect("Series", config.series if config.series else [None],
config.series, key="series", on_change=update_filter,
args=["series"], placeholder="None", label_visibility="collapsed")

# update types after changing axes, filters, and series
update_types()

# FIXME: can/should this section be placed before displaying current filters by caching earlier components?
st.write("###### Add New Filter")
with st.container(border=True):
c1, c2 = st.columns(2)
with c1:
st.selectbox("filter type", filter_types, key="filter_type")
with c2:
st.selectbox("column type", column_types, key="column_type")

c1, c2, c3 = st.columns(3)
with c1:
st.selectbox("filter column", post.df.columns, key="filter_col")
with c2:
if st.session_state.filter_type == "series":
st.selectbox("operator", ["=="], key="filter_op")
else:
st.selectbox("operator", operators, key="filter_op")
# FIXME: user should be allowed to select values that aren't in the column as well
with c3:
filter_col = post.df[st.session_state.filter_col].drop_duplicates()
st.selectbox("filter value", filter_col.sort_values(), key="filter_val")

current_filter = [st.session_state.filter_col,
st.session_state.filter_op,
st.session_state.filter_val]
st.button("Add Filter", on_click=add_filter, args=[current_filter])
# display axis options
axis_options()
# display filter options
filter_options()

generate_graph, download_config = st.columns(2)
# re-run post processing and create a new plot
with generate_graph:
st.button("Generate Graph", on_click=rerun_post_processing, use_container_width=True)
# download session state config
with download_config:
st.download_button("Download Config", config.to_yaml(),
"{0}_config.yaml".format((config.title).lower().replace(" ", "_")),
Expand All @@ -133,6 +76,27 @@ def update_config():
st.session_state.config = ConfigHandler(load_config(uploaded_config))


def axis_options():
"""
Display axis options interface.
"""

config = st.session_state.config
st.write("#### Axis Options")
with st.form(key="axis_options"):

# x-axis select
axis_select("x", config.x_axis)
sort = st.checkbox("sort descending", True if config.x_axis.get("sort") == "descending" else False)
# y-axis select
axis_select("y", config.y_axis)
submit = st.form_submit_button("Update Axes")

if submit:
update_axes()
config.x_axis["sort"] = "descending" if sort else "ascending"


def axis_select(label: str, axis: dict):
"""
Allow the user to select axis column and type for post-processing.
Expand Down Expand Up @@ -192,20 +156,21 @@ def update_axes():
"""

# FIXME (issue #271): if both axis columns are the same, this results in incorrect behaviour
config = st.session_state.config
x_column = st.session_state.x_axis_column
y_column = st.session_state.y_axis_column
x_units_column = st.session_state.x_axis_units_column
x_units_custom = st.session_state.x_axis_units_custom
y_units_column = st.session_state.y_axis_units_column
y_units_custom = st.session_state.y_axis_units_custom
state = st.session_state
config = state.config
x_column = state.x_axis_column
y_column = state.y_axis_column
x_units_column = state.x_axis_units_column
x_units_custom = state.x_axis_units_custom
y_units_column = state.y_axis_units_column
y_units_custom = state.y_axis_units_custom

# update columns
config.x_axis["value"] = x_column
config.y_axis["value"] = y_column
# update column types
config.column_types[x_column] = st.session_state.x_axis_type
config.column_types[y_column] = st.session_state.y_axis_type
config.column_types[x_column] = state.x_axis_type
config.column_types[y_column] = state.y_axis_type

# update units
# NOTE: units are automatically interpreted as strings for simplicity
Expand All @@ -220,6 +185,9 @@ def update_axes():
config.y_axis["units"] = {"column": y_units_column}
config.column_types[y_units_column] = "str"

# update types after changing axes
update_types()


def update_types():
"""
Expand All @@ -237,24 +205,106 @@ def update_types():
post.apply_df_types(config.all_columns, config.column_types)


def update_filter(key):
def filter_options():
"""
Display filter options interface.
"""

st.write("#### Filter Options")
# allow wide multiselect labels
st.markdown(
"""
<style>
.stMultiSelect [data-baseweb=select] span{
max-width: inherit;
}
</style>""",
unsafe_allow_html=True)

# display current filters
current_filters()
# display new filter addition options
new_filter_options()


def current_filters():
"""
Display current filters.
"""

config = st.session_state.config
st.write("###### Current AND Filters")
st.multiselect("AND Filters", config.and_filters if config.and_filters else [None],
config.and_filters, key="and", on_change=update_filter,
args=["and"], placeholder="None", label_visibility="collapsed")

st.write("###### Current OR Filters")
st.multiselect("OR Filters", config.or_filters if config.or_filters else [None],
config.or_filters, key="or", on_change=update_filter,
args=["or"], placeholder="None", label_visibility="collapsed")

st.write("###### Current Series")
st.multiselect("Series", config.series if config.series else [None],
config.series, key="series", on_change=update_filter,
args=["series"], placeholder="None", label_visibility="collapsed")


def new_filter_options():
"""
Display new filter addition options interface.
"""

state = st.session_state
post = state.post
# FIXME: can/should this section be placed before displaying current filters by caching earlier components?
st.write("###### Add New Filter")
with st.container(border=True):

c1, c2 = st.columns(2)
with c1:
st.selectbox("filter type", filter_types, key="filter_type")
with c2:
st.selectbox("column type", column_types, key="column_type")

c1, c2, c3 = st.columns(3)
with c1:
st.selectbox("filter column", post.df.columns, key="filter_col")
with c2:
if state.filter_type == "series":
st.selectbox("operator", ["=="], key="filter_op")
else:
st.selectbox("operator", operators, key="filter_op")
# FIXME: user should be allowed to select values that aren't in the column as well
with c3:
filter_col = post.df[state.filter_col].drop_duplicates()
st.selectbox("filter value", filter_col.sort_values(), key="filter_val")

current_filter = [state.filter_col, state.filter_op, state.filter_val]
st.button("Add Filter", on_click=add_filter, args=[current_filter])


def update_filter(key: str):
"""
Apply user-selected filters or series to session state config.
Args:
key: string, type of filter to update.
"""

config = st.session_state.config
state = st.session_state
config = state.config
if key == "series":
setattr(config, key, st.session_state[key])
setattr(config, key, state[key])
else:
config.filters[key] = st.session_state[key]
config.filters[key] = state[key]

# re-parse filters
config.parse_filters()
# update types after changing filters or series
update_types()


def add_filter(filter):
def add_filter(filter: list):
"""
Allow the user to add a new filter or series to session state config.
Expand All @@ -263,17 +313,18 @@ def add_filter(filter):
"""

# FIXME: there is a problem with filter datetime/timestamp formatting that requires further investigation
key = st.session_state.filter_type
loc = st.session_state[key]
if filter not in loc:
state = st.session_state
key = state.filter_type
if filter not in state[key]:
# remove operator from series
if key == "series":
del filter[1]
# add filter to appropriate list location and config
loc.append(filter)
update_filter(key)
# add filter to appropriate filter list
state[key].append(filter)
# update column type
st.session_state.config.column_types[filter[0]] = st.session_state.column_type
state.config.column_types[filter[0]] = state.column_type
# add filter to config and update df types
update_filter(key)


def rerun_post_processing():
Expand Down

0 comments on commit 11863ad

Please sign in to comment.