Commit e70b473

_search.search significantly cleaner
charles-turner-1 committed Nov 11, 2024
1 parent 00d1551 commit e70b473
Showing 1 changed file with 48 additions and 61 deletions.
src/intake_dataframe_catalog/_search.py (109 changes: 48 additions & 61 deletions)

@@ -75,19 +75,24 @@ def search(
     iterable_dtypes = {
         colname: type(df[colname].iloc[0]) for colname in columns_with_iterables
     }
-    iterable_qcols = [colname for colname in columns_with_iterables if colname in query]
+    columns_with_iterables = set(columns_with_iterables)
+    iterable_qcols = columns_with_iterables.intersection(query)
+    cols_to_deiter = set(all_cols).difference(columns_with_iterables, {name_column})
 
-    filter_first = True
     if require_all and not iterable_qcols:
         # If we've specified require all but we don't have any iterable columns
-        # in the query, we shouldn't do anything. Previous behaviour seemed to
-        # promote the query columns to iterables at this point.
-        iterable_qcols = [colname for colname in query if colname in all_cols]
-        columns_with_iterables = [*columns_with_iterables, *iterable_qcols]
+        # in the query we promote the query columns to iterables at this point.
+        group_on_names = True
+        iterable_qcols = set(query).intersection(all_cols)
 
         lf = lf.with_columns(
             [pl.col(colname).cast(pl.List(pl.Utf8)) for colname in iterable_qcols]
         )
-        filter_first = False
+        # Keep track of the newly promoted columns & the need to de-iterable them later
+        columns_with_iterables.update(iterable_qcols)
+        cols_to_deiter.update(iterable_qcols)
+    else:
+        group_on_names = False
 
     lf = lf.with_row_index()
     for column in columns_with_iterables:
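
This hunk swaps the old list comprehensions for set algebra: `iterable_qcols` becomes the intersection of the catalog's iterable columns with the query keys, and the new `cols_to_deiter` records up front which columns must be exploded back to scalars at the end. It also replaces the `filter_first` flag with `group_on_names`. A rough, self-contained sketch of that bookkeeping, using invented column names and toy data (nothing here is from the repository except the operations themselves):

    import polars as pl

    name_column = "name"
    all_cols = ["name", "model", "variable"]
    columns_with_iterables = {"variable"}

    # Case 1: the query touches an iterable column.
    query = {"variable": ["temp"]}
    iterable_qcols = columns_with_iterables.intersection(query)
    cols_to_deiter = set(all_cols).difference(columns_with_iterables, {name_column})
    print(iterable_qcols, cols_to_deiter)  # {'variable'} {'model'}

    # Case 2: require_all is set but the query touches no iterable column.
    # The queried columns are promoted to one-element lists so the same
    # length-based require_all filter can be applied to them later.
    query = {"model": ["ACCESS-OM2"]}
    iterable_qcols = set(query).intersection(all_cols)

    lf = pl.LazyFrame({"name": ["exp1"], "model": ["ACCESS-OM2"], "variable": [["temp"]]})
    lf = lf.with_columns([pl.col(c).cast(pl.List(pl.Utf8)) for c in iterable_qcols])

    # The promoted columns now count as iterable and need de-iterabling later.
    columns_with_iterables.update(iterable_qcols)
    cols_to_deiter.update(iterable_qcols)
    print(lf.collect_schema()["model"])  # List(String)
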
@@ -100,65 +105,47 @@ def search(
         else:
             lf = lf.filter(pl.col(colname).is_in(subquery))
 
-    lf = lf.group_by("index").agg(
-        [
-            pl.col(col).implode().flatten().unique(maintain_order=True)
-            for col in all_cols
-        ]
+    lf = (
+        lf.group_by("index")  # Piece the exploded columns back together
+        .agg(
+            [  # Re-aggregate the exploded columns into lists, flatten them out (imploding creates nested lists) and drop duplicates
+                pl.col(col).implode().flatten().unique(maintain_order=True)
+                for col in all_cols
+            ]
+        )
+        .drop("index")  # We don't need the index anymore
+        .explode(name_column)  # Explode the name column back out so we can select on it
     )
 
-    lf = lf.drop("index").select(all_cols)
-    lf = lf.explode(name_column)
-
-    if require_all and iterable_qcols and not lf.collect().is_empty():
-        # Find rows where list.len() >= query.len(), and get all the names in those rows
-        if filter_first:
-            nl = (
-                lf.filter(
-                    [
-                        pl.col(colname).list.len() >= len(query[colname])
-                        for colname in iterable_qcols
-                    ]
-                )
-                .select(name_column)
-                .collect()
-                .to_series()
-            )
-            lf = lf.filter(pl.col(name_column).is_in(nl))
+    if require_all and iterable_qcols:
+        if group_on_names:
+            # Group by name_column and aggregate the other columns into lists
+            # first in this instance. Essentially the opposite of the previous
+            # group_by("index") operation.
+            nl_lf = lf.group_by(name_column).agg(
+                [
+                    pl.col(col).explode().flatten().unique(maintain_order=True)
+                    for col in (set(all_cols) - {name_column})
+                ]
+            )
         else:
-            nl = (
-                lf.group_by(name_column)
-                .agg(
-                    [
-                        pl.col(col).explode().flatten().unique(maintain_order=True)
-                        for col in all_cols
-                        if col != name_column
-                    ]
-                )
-                .filter(
-                    [
-                        pl.col(colname).list.len() >= len(query[colname])
-                        for colname in iterable_qcols
-                    ]
-                )
-                .select(name_column)
-                .collect()
-                .to_series()
-            )
-            lf = lf.filter(pl.col(name_column).is_in(nl))
-
-    # Now we 'de-iterable' the non-iterable columns.
-    non_iter_cols = [
-        col
-        for col in lf.collect_schema().names()
-        if col not in [*columns_with_iterables, name_column]
-    ]
-    lf = lf.explode(non_iter_cols)
-    if not filter_first:
-        # We also need to 'de-iterable' the query columns
-        lf = lf.explode(iterable_qcols)
-
-    df = lf.collect().to_pandas()
+            nl_lf = lf
+
+        nl = (
+            nl_lf.filter(
+                [
+                    pl.col(colname).list.len() >= len(query[colname])
+                    for colname in iterable_qcols
+                ]
+            )
+            .select(name_column)
+            .collect()
+            .to_series()
+        )
+        lf = lf.filter(pl.col(name_column).is_in(nl))
+
+    df = lf.explode(list(cols_to_deiter)).collect().to_pandas()
 
     for col, dtype in iterable_dtypes.items():
         df[col] = df[col].apply(lambda x: dtype(x))
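
The second hunk collapses the explode/filter/reassemble sequence into one chain, with the mechanics unchanged: each iterable column is exploded so the scalar `is_in` filter applies per value, then `group_by("index")` reassembles the surviving pieces of each original row, with `implode().flatten().unique()` undoing the nesting that aggregation introduces, and the name column exploded back out. A runnable sketch of that round-trip, using the same polars calls as the diff but on hypothetical toy data:

    import polars as pl

    name_column = "name"
    all_cols = ["name", "model", "variable"]
    columns_with_iterables = {"variable"}
    query = {"variable": ["temp", "salt"]}

    # Invented stand-in for the catalog: one row per source.
    lf = pl.LazyFrame(
        {
            "name": ["exp1", "exp2"],
            "model": ["ACCESS-OM2", "ACCESS-CM2"],
            "variable": [["temp", "salt"], ["temp"]],
        }
    )

    # Tag each original row, then explode the iterable columns so the
    # scalar is_in filter sees one value per row.
    lf = lf.with_row_index()
    for colname in columns_with_iterables:
        lf = lf.explode(colname).filter(pl.col(colname).is_in(query[colname]))

    # Reassemble each original row: implode() nests the grouped values,
    # flatten() undoes the nesting, unique() drops duplicates from the explode.
    lf = (
        lf.group_by("index")
        .agg(
            [
                pl.col(col).implode().flatten().unique(maintain_order=True)
                for col in all_cols
            ]
        )
        .drop("index")
        .explode(name_column)
    )

    print(lf.collect())  # exp1 keeps ["temp", "salt"]; exp2 keeps ["temp"]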

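The `require_all` branch now always filters a single frame, `nl_lf`, built either by regrouping on the name column (when scalar query columns were promoted) or taken as-is: a source survives only if each queried list column kept at least as many values as the query supplied. A sketch of that final filter and the closing `explode` over `cols_to_deiter`, again on invented data shaped like the post-aggregation frame:

    import polars as pl

    name_column = "name"
    query = {"variable": ["temp", "salt"]}
    iterable_qcols = {"variable"}
    cols_to_deiter = {"model"}

    # Hypothetical post-aggregation state: every non-name column now holds
    # lists of the values that survived the is_in filters.
    lf = pl.LazyFrame(
        {
            "name": ["exp1", "exp2"],
            "model": [["ACCESS-OM2"], ["ACCESS-CM2"]],
            "variable": [["temp", "salt"], ["temp"]],
        }
    )

    # require_all: a source matches only if it kept at least as many values
    # as the query asked for, i.e. every queried value was found.
    nl = (
        lf.filter(
            [
                pl.col(colname).list.len() >= len(query[colname])
                for colname in iterable_qcols
            ]
        )
        .select(name_column)
        .collect()
        .to_series()
    )
    lf = lf.filter(pl.col(name_column).is_in(nl))

    # Finally "de-iterable" the columns that should come back as scalars.
    df = lf.explode(list(cols_to_deiter)).collect().to_pandas()
    print(df)  # only exp1 survives: it matched both "temp" and "salt"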
