Skip to content

Commit

Permalink
Fix parallelisation on find_applicable
Browse files: browse the repository at this point in the history
  • Loading branch information
Sacha Beniamine committed Jan 15, 2025
1 parent d77fac8 commit ebf3318
Showing 1 changed file with 12 additions and 22 deletions.
34 changes: 12 additions & 22 deletions src/qumin/representations/patterns.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -900,7 +900,6 @@ def find_patterns(self, paradigms, *args, method="edits", disable_tqdm=False, cp
list(self)), # list of dict keys = the pairs
total=comb(len(self.cells), 2)))


def __repr__(self):
if len(self.cells) == 0:
return "ParadigmPatterns(empty)"
Expand Down Expand Up @@ -1223,20 +1222,16 @@ def find_cellpair_applicable(self, pair):
Dataframe of applicable patterns.
"""

def _iter_applicable_patterns(row, pair):
cell_x = pair[0]
for pattern in available_patterns:
if pattern.applicable(row.form_x, cell_x):
yield pattern
df = self[pair]
available_patterns = self.unique_patterns(pair)
cell_x = pair[0]

def applicable(*args):
"""Returns all applicable patterns to a single row"""
return tuple(_iter_applicable_patterns(*args))
def applicable(form):
""" Return a tuple of all applicable patterns for a given form"""
return tuple((p for p in available_patterns if p.applicable(form, cell_x)))

available_patterns = self.unique_patterns(pair)
df = self[pair]
has_pat = df.pattern.notna()
return df.loc[has_pat,:].apply(applicable, axis=1)
has_pat = df['pattern'].notna()
return (pair, df.loc[has_pat, "form_x"].apply(applicable))

def unique_patterns(self, pair):
""" Get a unique sequence of available patterns for a pair of cells.
Expand Down Expand Up @@ -1270,15 +1265,10 @@ def find_applicable(self, disable_tqdm=False, cpus=1, **kwargs):
to_add[key[::-1]] = self[key].rename(columns=col_rename)
self.update(to_add)

log.info("total cpus: " + str(cpus))
# Compute
with Pool(cpus) as pool: # Create a multiprocessing Pool
applicables = tqdm(pool.imap_unordered(self.find_cellpair_applicable,
list(self)), # list of dict keys = the pairs
total=comb(len(self.cells), 2))

# Update self
for pair, res in applicables:
df = self[pair]
df.loc[res.index, "applicable"] = res


for pair, res in tqdm(pool.imap_unordered(self.find_cellpair_applicable, self), total=len(self)):
df = self[pair]
df.loc[res.index, "applicable"] = res

0 comments on commit ebf3318

Please sign in to comment.