diff --git a/great_tables/_data_color/base.py b/great_tables/_data_color/base.py index 85ca8e7c7..ffc25e45a 100644 --- a/great_tables/_data_color/base.py +++ b/great_tables/_data_color/base.py @@ -3,7 +3,8 @@ from typing import TYPE_CHECKING import numpy as np -from great_tables._tbl_data import DataFrameLike, is_na +from great_tables._locations import resolve_cols_c, resolve_rows_i, RowSelectExpr +from great_tables._tbl_data import DataFrameLike, is_na, SelectExpr from great_tables.loc import body from great_tables.style import fill, text from typing_extensions import TypeAlias @@ -19,7 +20,8 @@ def data_color( self: GTSelf, - columns: str | list[str] | None = None, + columns: SelectExpr = None, + rows: RowSelectExpr = None, palette: str | list[str] | None = None, domain: list[str] | list[int] | list[float] | None = None, na_color: str | None = None, @@ -47,6 +49,10 @@ def data_color( columns The columns to target. Can either be a single column name or a series of column names provided in a list. + rows + In conjunction with `columns=`, we can specify which rows should be colored. By default, + all rows in the targeted columns will be colored. Alternatively, we can provide a list + of row indices. palette The color palette to use. This should be a list of colors (e.g., `["#FF0000", "#00FF00", "#0000FF"]`). A ColorBrewer palette could also be used, just supply the name (reference @@ -202,18 +208,20 @@ def data_color( # get a list of all columns in the table body columns_resolved: list[str] - if isinstance(columns, str): - columns_resolved = [columns] - elif columns is None: + if columns is None: columns_resolved = data_table.columns else: - columns_resolved = columns + columns_resolved = resolve_cols_c(data=self, expr=columns) + + row_res = resolve_rows_i(self, rows) + row_pos = [name_pos[1] for name_pos in row_res] gt_obj = self # For each column targeted, get the data values as a new list object for col in columns_resolved: - column_vals = data_table[col].to_list() + # This line handles both pandas and polars dataframes + column_vals = data_table[col][row_pos].to_list() # Filter out NA values from `column_vals` filtered_column_vals = [x for x in column_vals if not is_na(data_table, x)] @@ -260,7 +268,7 @@ def data_color( # for every color value in color_vals, apply a fill to the corresponding cell # by using `tab_style()` - for i, color_val in enumerate(color_vals): + for i, color_val in zip(row_pos, color_vals): if autocolor_text: fgnd_color = _ideal_fgnd_color(bgnd_color=color_val) diff --git a/tests/data_color/__snapshots__/test_data_color.ambr b/tests/data_color/__snapshots__/test_data_color.ambr index df6e233ab..620ada1a0 100644 --- a/tests/data_color/__snapshots__/test_data_color.ambr +++ b/tests/data_color/__snapshots__/test_data_color.ambr @@ -362,6 +362,66 @@ ''' # --- +# name: test_data_color_pd_cols_rows_snap + ''' +
+