Skip to content

Commit

Permalink
All tests should be passing now - search function tractable, needs re…
Browse files Browse the repository at this point in the history
…factoring
  • Loading branch information
charles-turner-1 committed Nov 11, 2024
1 parent 7564c71 commit 00d1551
Showing 1 changed file with 53 additions and 14 deletions.
67 changes: 53 additions & 14 deletions src/intake_dataframe_catalog/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,30 @@ def search(

if not query:
return df
if require_all and len(query.get(name_column, [""])) > 1:
return df.head(0)

lf: pl.LazyFrame = pl.from_pandas(df).lazy()
col_order = lf.columns
all_cols = lf.columns

# Keep the iterable columns and their dtypes hanging around for later
iterable_dtypes = {
colname: type(df[colname].iloc[0]) for colname in columns_with_iterables
}
iterable_qcols = [colname for colname in columns_with_iterables if colname in query]

filter_first = True
if require_all and not iterable_qcols:
# If we've specified require all but we don't have any iterable columns
# in the query, we shouldn't do anything. Previous behaviour seemed to
# promote the query columns to iterables at this point.
iterable_qcols = [colname for colname in query if colname in all_cols]
columns_with_iterables = [*columns_with_iterables, *iterable_qcols]
lf = lf.with_columns(
[pl.col(colname).cast(pl.List(pl.Utf8)) for colname in iterable_qcols]
)
filter_first = False

lf = lf.with_row_index()
for column in columns_with_iterables:
lf = lf.explode(column)
Expand All @@ -89,27 +103,49 @@ def search(
lf = lf.group_by("index").agg(
[
pl.col(col).implode().flatten().unique(maintain_order=True)
for col in col_order
for col in all_cols
]
)

lf = lf.drop("index").select(col_order)
lf = lf.drop("index").select(all_cols)
lf = lf.explode(name_column)

if require_all and iterable_qcols and not lf.collect().is_empty():
# Find rows where list.len() >= query.len(), and get all the names in those rows
nl = (
lf.filter(
[
pl.col(colname).list.len() >= len(query[colname])
for colname in iterable_qcols
]
if filter_first:
nl = (
lf.filter(
[
pl.col(colname).list.len() >= len(query[colname])
for colname in iterable_qcols
]
)
.select(name_column)
.collect()
.to_series()
)
.select(name_column)
.collect()
.to_series()
)
lf = lf.filter(pl.col(name_column).is_in(nl))
lf = lf.filter(pl.col(name_column).is_in(nl))
else:
nl = (
lf.group_by(name_column)
.agg(
[
pl.col(col).explode().flatten().unique(maintain_order=True)
for col in all_cols
if col != name_column
]
)
.filter(
[
pl.col(colname).list.len() >= len(query[colname])
for colname in iterable_qcols
]
)
.select(name_column)
.collect()
.to_series()
)
lf = lf.filter(pl.col(name_column).is_in(nl))

# Now we 'de-iterable' the non-iterable columns.
non_iter_cols = [
Expand All @@ -118,6 +154,9 @@ def search(
if col not in [*columns_with_iterables, name_column]
]
lf = lf.explode(non_iter_cols)
if not filter_first:
# We also need to 'de-iterable' the query columns
lf = lf.explode(iterable_qcols)

df = lf.collect().to_pandas()
for col, dtype in iterable_dtypes.items():
Expand Down

0 comments on commit 00d1551

Please sign in to comment.