diff --git a/src/intake_dataframe_catalog/_search.py b/src/intake_dataframe_catalog/_search.py index da7b0fa..b9dbc5c 100644 --- a/src/intake_dataframe_catalog/_search.py +++ b/src/intake_dataframe_catalog/_search.py @@ -65,9 +65,11 @@ def search( if not query: return df + if require_all and len(query.get(name_column, [""])) > 1: + return df.head(0) lf: pl.LazyFrame = pl.from_pandas(df).lazy() - col_order = lf.columns + all_cols = lf.columns # Keep the iterable columns and their dtypes hanging around for later iterable_dtypes = { @@ -75,6 +77,18 @@ def search( } iterable_qcols = [colname for colname in columns_with_iterables if colname in query] + filter_first = True + if require_all and not iterable_qcols: + # If we've specified require all but we don't have any iterable columns + # in the query, we shouldn't do anything. Previous behaviour seemed to + # promote the query columns to iterables at this point. + iterable_qcols = [colname for colname in query if colname in all_cols] + columns_with_iterables = [*columns_with_iterables, *iterable_qcols] + lf = lf.with_columns( + [pl.col(colname).cast(pl.List(pl.Utf8)) for colname in iterable_qcols] + ) + filter_first = False + lf = lf.with_row_index() for column in columns_with_iterables: lf = lf.explode(column) @@ -89,27 +103,49 @@ def search( lf = lf.group_by("index").agg( [ pl.col(col).implode().flatten().unique(maintain_order=True) - for col in col_order + for col in all_cols ] ) - lf = lf.drop("index").select(col_order) + lf = lf.drop("index").select(all_cols) lf = lf.explode(name_column) if require_all and iterable_qcols and not lf.collect().is_empty(): # Find rows where list.len() >= query.len(), and get all the names in those rows - nl = ( - lf.filter( - [ - pl.col(colname).list.len() >= len(query[colname]) - for colname in iterable_qcols - ] + if filter_first: + nl = ( + lf.filter( + [ + pl.col(colname).list.len() >= len(query[colname]) + for colname in iterable_qcols + ] + ) + .select(name_column) + .collect() + .to_series() ) - .select(name_column) - .collect() - .to_series() - ) - lf = lf.filter(pl.col(name_column).is_in(nl)) + lf = lf.filter(pl.col(name_column).is_in(nl)) + else: + nl = ( + lf.group_by(name_column) + .agg( + [ + pl.col(col).explode().flatten().unique(maintain_order=True) + for col in all_cols + if col != name_column + ] + ) + .filter( + [ + pl.col(colname).list.len() >= len(query[colname]) + for colname in iterable_qcols + ] + ) + .select(name_column) + .collect() + .to_series() + ) + lf = lf.filter(pl.col(name_column).is_in(nl)) # Now we 'de-iterable' the non-iterable columns. non_iter_cols = [ @@ -118,6 +154,9 @@ def search( if col not in [*columns_with_iterables, name_column] ] lf = lf.explode(non_iter_cols) + if not filter_first: + # We also need to 'de-iterable' the query columns + lf = lf.explode(iterable_qcols) df = lf.collect().to_pandas() for col, dtype in iterable_dtypes.items():