diff --git a/src/ia_collection_analyzer/streamlit.py b/src/ia_collection_analyzer/streamlit.py index cd1314fc..5da7cf7b 100644 --- a/src/ia_collection_analyzer/streamlit.py +++ b/src/ia_collection_analyzer/streamlit.py @@ -38,9 +38,10 @@ if "original_values" not in st.session_state: st.session_state.original_values = {} + @st.fragment def collection_input(): - """Fragment for collection ID input and metadata fetching""" + """Fragment for collection ID input and metadata fetching""" # input the collection name col1, col2 = st.columns([6, 1], vertical_alignment="bottom") with col1: @@ -59,7 +60,10 @@ def collection_input(): if not conform_button and not st.session_state.got_metadata or collection_id == "": st.stop() - if st.session_state.got_metadata and collection_id == st.session_state.collection_id: + if ( + st.session_state.got_metadata + and collection_id == st.session_state.collection_id + ): items_pd = st.session_state.items_pd # progress_message progress_message = st.session_state.progress_message @@ -71,14 +75,15 @@ def collection_input(): st.markdown("Failed to display top 10 lines. Only first will be shown.") st.write(items_pd.head(1)) st.write(e) - + return # Check if we need to fetch new data - if not st.session_state.got_metadata or collection_id != st.session_state.collection_id: - st.markdown( - f"Getting fresh metadata for collection: **{collection_id}**" - ) + if ( + not st.session_state.got_metadata + or collection_id != st.session_state.collection_id + ): + st.markdown(f"Getting fresh metadata for collection: **{collection_id}**") items, progress_message = fetch_metadata(collection_id) data_transform_text = st.text("Transforming data...") items_pd = pd.DataFrame(items) @@ -117,29 +122,32 @@ def collection_input(): # Update cache st.session_state.items_pd = items_pd else: - st.markdown( - f"Using cached metadata for collection: **{collection_id}**" - ) + st.markdown(f"Using cached metadata for collection: **{collection_id}**") items_pd = st.session_state.items_pd st.session_state.got_metadata = True st.session_state.collection_id = collection_id st.session_state.progress_message = progress_message st.session_state.selected_columns = [] - + st.rerun() + @st.fragment def column_selector(): """Fragment for selecting columns to analyze""" items_pd = st.session_state.items_pd - + st.header("Selecting columns to analyze") st.write("Select additional columns you want to analyze:") - seleactable_columns = [col for col in items_pd.columns if col not in REQUIRED_METADATA] + seleactable_columns = [ + col for col in items_pd.columns if col not in REQUIRED_METADATA + ] col1, col2 = st.columns([6, 1], vertical_alignment="bottom") - selected_columns = st.multiselect("Select columns:", seleactable_columns, default=[]) + selected_columns = st.multiselect( + "Select columns:", seleactable_columns, default=[] + ) # Update the filtering code to use cache if ( @@ -159,6 +167,7 @@ def column_selector(): st.write("Preview of the selected columns:") st.write(filtered_pd.head(30)) + @st.fragment def transform_data(): """Fragment for transforming data""" @@ -168,13 +177,12 @@ def transform_data(): index=0, placeholder="No", ) - + if transform_needed == "No": return filtered_pd = st.session_state.filtered_pd - - + st.header("Transform Column") st.write("Transform an existing column with data transformations") @@ -360,18 +368,19 @@ def transform_data(): {"source_col": source_col, "transform_type": transform_type} ) st.session_state.original_values[source_col] = preview_df["Original"] - + st.rerun() + @st.fragment def plot_data(): """Fragment for data visualization""" if not st.session_state.filtered_pd is not None: return - + filtered_pd = st.session_state.filtered_pd plotable_columns = st.session_state.selected_columns + REQUIRED_METADATA - + col1, col2, col3 = st.columns([3, 3, 1], vertical_alignment="bottom") with col1: x_axis = st.selectbox("Select the x-axis:", plotable_columns, index=0) @@ -442,6 +451,7 @@ def plot_data(): counts_df = pd.crosstab(expanded_df[x_axis], expanded_df[y_axis]) st.write(counts_df) + def main(): collection_input() if st.session_state.got_metadata: @@ -450,5 +460,6 @@ def main(): transform_data() plot_data() + if __name__ == "__main__": - main() \ No newline at end of file + main()