diff --git a/great_tables/_data_color/base.py b/great_tables/_data_color/base.py index 85ca8e7c7..ffc25e45a 100644 --- a/great_tables/_data_color/base.py +++ b/great_tables/_data_color/base.py @@ -3,7 +3,8 @@ from typing import TYPE_CHECKING import numpy as np -from great_tables._tbl_data import DataFrameLike, is_na +from great_tables._locations import resolve_cols_c, resolve_rows_i, RowSelectExpr +from great_tables._tbl_data import DataFrameLike, is_na, SelectExpr from great_tables.loc import body from great_tables.style import fill, text from typing_extensions import TypeAlias @@ -19,7 +20,8 @@ def data_color( self: GTSelf, - columns: str | list[str] | None = None, + columns: SelectExpr = None, + rows: RowSelectExpr = None, palette: str | list[str] | None = None, domain: list[str] | list[int] | list[float] | None = None, na_color: str | None = None, @@ -47,6 +49,10 @@ def data_color( columns The columns to target. Can either be a single column name or a series of column names provided in a list. + rows + In conjunction with `columns=`, we can specify which rows should be colored. By default, + all rows in the targeted columns will be colored. Alternatively, we can provide a list + of row indices. palette The color palette to use. This should be a list of colors (e.g., `["#FF0000", "#00FF00", "#0000FF"]`). A ColorBrewer palette could also be used, just supply the name (reference @@ -202,18 +208,20 @@ def data_color( # get a list of all columns in the table body columns_resolved: list[str] - if isinstance(columns, str): - columns_resolved = [columns] - elif columns is None: + if columns is None: columns_resolved = data_table.columns else: - columns_resolved = columns + columns_resolved = resolve_cols_c(data=self, expr=columns) + + row_res = resolve_rows_i(self, rows) + row_pos = [name_pos[1] for name_pos in row_res] gt_obj = self # For each column targeted, get the data values as a new list object for col in columns_resolved: - column_vals = data_table[col].to_list() + # This line handles both pandas and polars dataframes + column_vals = data_table[col][row_pos].to_list() # Filter out NA values from `column_vals` filtered_column_vals = [x for x in column_vals if not is_na(data_table, x)] @@ -260,7 +268,7 @@ def data_color( # for every color value in color_vals, apply a fill to the corresponding cell # by using `tab_style()` - for i, color_val in enumerate(color_vals): + for i, color_val in zip(row_pos, color_vals): if autocolor_text: fgnd_color = _ideal_fgnd_color(bgnd_color=color_val) diff --git a/tests/data_color/__snapshots__/test_data_color.ambr b/tests/data_color/__snapshots__/test_data_color.ambr index df6e233ab..620ada1a0 100644 --- a/tests/data_color/__snapshots__/test_data_color.ambr +++ b/tests/data_color/__snapshots__/test_data_color.ambr @@ -362,6 +362,66 @@ ''' # --- +# name: test_data_color_pd_cols_rows_snap + ''' + + + 1 + 51 + + + 2 + 52 + + + 3 + 53 + + + 4 + 54 + + + 5 + 55 + + + 200 + 200 + + + ''' +# --- +# name: test_data_color_pl_cols_rows_snap + ''' + + + 1 + 51 + + + 2 + 52 + + + 3 + 53 + + + 4 + 54 + + + 5 + 55 + + + 200 + 200 + + + ''' +# --- # name: test_data_color_simple_df_snap ''' diff --git a/tests/data_color/test_data_color.py b/tests/data_color/test_data_color.py index 186deaa5c..eb01334ce 100644 --- a/tests/data_color/test_data_color.py +++ b/tests/data_color/test_data_color.py @@ -55,6 +55,28 @@ def test_data_color_simple_exibble_snap(snapshot: str, df: DataFrameLike): assert_rendered_body(snapshot, gt) +def test_data_color_pd_cols_rows_snap(snapshot: str): + df = pd.DataFrame({"a": [1, 2, 3, 4, 5, 200], "b": [51, 52, 53, 54, 55, 200]}) + new_gt = GT(df).data_color(columns=["a"], rows=[0, 1, 2, 3, 4]) + assert_rendered_body(snapshot, new_gt) + new_gt2 = GT(df).data_color(columns=["a"], rows=lambda df_: df_["a"].lt(60)) + assert create_body_component_h(new_gt._build_data("html")) == create_body_component_h( + new_gt2._build_data("html") + ) + + +def test_data_color_pl_cols_rows_snap(snapshot: str): + import polars.selectors as cs + + df = pl.DataFrame({"a": [1, 2, 3, 4, 5, 200], "b": [51, 52, 53, 54, 55, 200]}) + new_gt = GT(df).data_color(columns=["b"], rows=[0, 1, 2, 3, 4]) + assert_rendered_body(snapshot, new_gt) + new_gt2 = GT(df).data_color(columns=cs.starts_with("b"), rows=pl.col("b").lt(60)) + assert create_body_component_h(new_gt._build_data("html")) == create_body_component_h( + new_gt2._build_data("html") + ) + + @pytest.mark.parametrize("none_val", [None, np.nan, float("nan"), pd.NA]) @pytest.mark.parametrize("df_cls", [pd.DataFrame, pl.DataFrame]) def test_data_color_missing_value(df_cls, none_val):