updated the strategy design to separate traversal from node selection #152

Open · wants to merge 3 commits into main
67 changes: 35 additions & 32 deletions packages/graph-retriever/src/graph_retriever/strategies/base.py
@@ -10,6 +10,28 @@
from graph_retriever.types import Node


class NodeTracker:
"""Helper class for tracking traversal progress."""

def __init__(self) -> None:
self.visited_ids: set[str] = set()
self.to_traverse: dict[str, Node] = {}
self.selected: dict[str, Node] = {}

def select(self, nodes: dict[str, Node]) -> None:
"""Select nodes to be included in the result set."""
self.selected.update(nodes)

def traverse(self, nodes: dict[str, Node]) -> None:
"""Select nodes to be included in the next traversal."""
self.to_traverse.update(nodes)

def select_and_traverse(self, nodes: dict[str, Node]) -> None:
"""Select nodes to be included in the result set and the next traversal."""
self.select(nodes=nodes)
self.traverse(nodes=nodes)


@dataclasses.dataclass(kw_only=True)
class Strategy(abc.ABC):
"""
@@ -22,30 +44,34 @@ class Strategy(abc.ABC):
Parameters
----------
k :
Maximum number of nodes to retrieve during traversal.
Maximum number of nodes to select and return during traversal.
start_k :
Number of documents to fetch via similarity for starting the traversal.
Added to any initial roots provided to the traversal.
adjacent_k :
Number of documents to fetch for each outgoing edge.
traverse_k :
Maximum number of nodes to traverse outgoing edges from before returning.
max_depth :
Maximum traversal depth. If `None`, there is no limit.
"""

k: int = 5
start_k: int = 4
adjacent_k: int = 10
traverse_k: int = 4 # max_traverse?
Collaborator (author) commented: this isn't currently used anywhere.

max_depth: int | None = None

_query_embedding: list[float] = dataclasses.field(default_factory=list)

@abc.abstractmethod
def discover_nodes(self, nodes: dict[str, Node]) -> None:
def iteration(self, *, nodes: dict[str, Node], tracker: NodeTracker) -> None:
"""
Add discovered nodes to the strategy.
Process the newly discovered nodes on each iteration.

This method updates the strategy's state with nodes discovered during
the traversal process.
This method should call `traverse` and/or `select` as appropriate
to update the nodes that need to be traversed in this iteration or
selected at the end of the retrieval, respectively.

Parameters
----------
@@ -54,43 +80,20 @@ def discover_nodes(self, nodes: dict[str, Node]) -> None:
"""
...

@abc.abstractmethod
def select_nodes(self, *, limit: int) -> Iterable[Node]:
"""
Select discovered nodes to visit in the next iteration.

This method determines which nodes will be traversed next. If it returns
an empty list, traversal ends even if fewer than `k` nodes have been selected.

Parameters
----------
limit :
Maximum number of nodes to select.

Returns
-------
:
Selected nodes for the next iteration. Traversal ends if this is empty.
"""
...

def finalize_nodes(self, nodes: Iterable[Node]) -> Iterable[Node]:
def finalize_nodes(self, selected: Iterable[Node]) -> Iterable[Node]:
"""
Finalize the selected nodes.

This method is called before returning the final set of nodes.

Parameters
----------
nodes :
Nodes selected for finalization.

Returns
-------
:
Finalized nodes.
"""
return nodes
# Take the first `self.k` selected items.
# Strategies may override finalize to perform reranking if needed.
return list(selected)[: self.k]

@staticmethod
def build(
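To illustrate the new contract, here is a minimal sketch of a custom strategy written against `NodeTracker` and `Strategy.iteration` as introduced in this diff. The `ThresholdStrategy` class, its `scorer` and `threshold` fields, and the reranking in `finalize_nodes` are illustrative assumptions, not part of this PR.

import dataclasses
from collections.abc import Callable, Iterable

from graph_retriever.strategies.base import NodeTracker, Strategy
from graph_retriever.types import Node


@dataclasses.dataclass
class ThresholdStrategy(Strategy):
    """Hypothetical strategy: traverse everything, select only high scorers."""

    # Illustrative fields; any callable mapping a Node to a float works here.
    scorer: Callable[[Node], float] = lambda _node: 1.0
    threshold: float = 0.5

    def iteration(self, *, nodes: dict[str, Node], tracker: NodeTracker) -> None:
        # Traversal and selection are now independent decisions: keep
        # expanding the edges of every newly discovered node...
        tracker.traverse(nodes)
        # ...but only add the nodes that clear the threshold to the results.
        good = {id_: n for id_, n in nodes.items() if self.scorer(n) >= self.threshold}
        tracker.select(good)

    def finalize_nodes(self, selected: Iterable[Node]) -> Iterable[Node]:
        # Optionally rerank before truncating to the top `k`, as the base
        # implementation's comment suggests strategies may do.
        return sorted(selected, key=self.scorer, reverse=True)[: self.k]

Compared with the old `discover_nodes`/`select_nodes` split, the single per-iteration hook lets a strategy expand a node's edges without committing to returning it, and vice versa.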
15 changes: 3 additions & 12 deletions packages/graph-retriever/src/graph_retriever/strategies/eager.py
@@ -1,11 +1,10 @@
"""Provide eager (breadth-first) traversal strategy."""

import dataclasses
from collections.abc import Iterable

from typing_extensions import override

from graph_retriever.strategies.base import Strategy
from graph_retriever.strategies.base import NodeTracker, Strategy
from graph_retriever.types import Node


@@ -32,14 +31,6 @@ class Eager(Strategy):
Maximum traversal depth. If `None`, there is no limit.
"""

_nodes: list[Node] = dataclasses.field(default_factory=list)

@override
def discover_nodes(self, nodes: dict[str, Node]) -> None:
self._nodes.extend(nodes.values())

@override
def select_nodes(self, *, limit: int) -> Iterable[Node]:
nodes = self._nodes[:limit]
self._nodes = []
return nodes
def iteration(self, nodes: dict[str, Node], tracker: NodeTracker) -> None:
tracker.select_and_traverse(nodes)
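With the tracker in place, the eager strategy reduces to a single call: every discovered node is both selected and traversed. A usage sketch (the argument values are arbitrary; the keyword arguments are the dataclass fields documented in base.py above):

from graph_retriever.strategies.eager import Eager

# Breadth-first retrieval: return up to 10 nodes, seed the traversal with 4
# similarity hits, fetch up to 10 adjacent documents per edge, stop at depth 2.
strategy = Eager(k=10, start_k=4, adjacent_k=10, max_depth=2)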
105 changes: 48 additions & 57 deletions packages/graph-retriever/src/graph_retriever/strategies/mmr.py
@@ -8,7 +8,7 @@
from numpy.typing import NDArray
from typing_extensions import override

from graph_retriever.strategies.base import Strategy
from graph_retriever.strategies.base import NodeTracker, Strategy
from graph_retriever.types import Node
from graph_retriever.utils.math import cosine_similarity

@@ -203,8 +203,7 @@ def _pop_candidate(

return candidate, embedding

@override
def select_nodes(self, *, limit: int) -> Iterable[Node]:
def _next(self) -> dict[str, Node]:
"""
Select and pop the best item being considered.

@@ -214,10 +213,8 @@ def select_nodes(self, *, limit: int) -> Iterable[Node]:
-------
The best remaining node keyed by its ID, or an empty dict if no candidate qualifies.
"""
if limit == 0:
return []
if self._best_id is None or self._best_score < self.min_mmr_score:
return []
return {}

# Get the selection and remove from candidates.
selected_id = self._best_id
@@ -250,61 +247,55 @@ def select_nodes(self, *, limit: int) -> Iterable[Node]:
self._best_score = candidate.score
self._best_id = candidate.node.id

return [selected_node]
return {selected_node.id: selected_node}

@override
def discover_nodes(self, nodes: dict[str, Node]) -> None:
def iteration(self, nodes: dict[str, Node], tracker: NodeTracker) -> None:
"""Add candidates to the consideration set."""
# Determine the keys to actually include.
# These are the candidates that aren't already selected
# or under consideration.

include_ids_set = set(nodes.keys())
include_ids_set.difference_update(self._selected_ids)
include_ids_set.difference_update(self._candidate_id_to_index.keys())
include_ids = list(include_ids_set)

# Now, build up a matrix of the remaining candidate embeddings.
# And add them to the
new_embeddings: NDArray[np.float32] = np.ndarray(
(
len(include_ids),
self._dimensions,
if len(nodes) > 0:
# Build up a matrix of the remaining candidate embeddings.
# And add them to the candidate set
new_embeddings: NDArray[np.float32] = np.ndarray(
(
len(nodes),
self._dimensions,
)
)
)
offset = self._candidate_embeddings.shape[0]
for index, candidate_id in enumerate(include_ids):
self._candidate_id_to_index[candidate_id] = offset + index
new_embeddings[index] = nodes[candidate_id].embedding

# Compute the similarity to the query.
similarity = cosine_similarity(new_embeddings, self._nd_query_embedding)

# Compute the distance metrics of all of pairs in the selected set with
# the new candidates.
redundancy = cosine_similarity(
new_embeddings, self._already_selected_embeddings()
)
for index, candidate_id in enumerate(include_ids):
max_redundancy = 0.0
if redundancy.shape[0] > 0:
max_redundancy = redundancy[index].max()
candidate = _MmrCandidate(
node=nodes[candidate_id],
similarity=similarity[index][0],
weighted_similarity=self.lambda_mult * similarity[index][0],
weighted_redundancy=self._lambda_mult_complement * max_redundancy,
offset = self._candidate_embeddings.shape[0]
for index, candidate_id in enumerate(nodes.keys()):
self._candidate_id_to_index[candidate_id] = offset + index
new_embeddings[index] = nodes[candidate_id].embedding

# Compute the similarity to the query.
similarity = cosine_similarity(new_embeddings, self._nd_query_embedding)

# Compute the distance metrics of all of pairs in the selected set with
# the new candidates.
redundancy = cosine_similarity(
new_embeddings, self._already_selected_embeddings()
)
self._candidates.append(candidate)

if candidate.score >= self._best_score:
self._best_score = candidate.score
self._best_id = candidate.node.id
for index, candidate_id in enumerate(nodes.keys()):
max_redundancy = 0.0
if redundancy.shape[0] > 0:
max_redundancy = redundancy[index].max()
candidate = _MmrCandidate(
node=nodes[candidate_id],
similarity=similarity[index][0],
weighted_similarity=self.lambda_mult * similarity[index][0],
weighted_redundancy=self._lambda_mult_complement * max_redundancy,
)
self._candidates.append(candidate)

if candidate.score >= self._best_score:
self._best_score = candidate.score
self._best_id = candidate.node.id

# Add the new embeddings to the candidate set.
self._candidate_embeddings = np.vstack(
(
self._candidate_embeddings,
new_embeddings,
# Add the new embeddings to the candidate set.
self._candidate_embeddings = np.vstack(
(
self._candidate_embeddings,
new_embeddings,
)
)
)

tracker.select_and_traverse(self._next())
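Under the new API, MMR's `iteration` registers all newly discovered nodes as candidates and then selects and traverses only the single best one per call via `_next()`. The `weighted_similarity` and `weighted_redundancy` fields suggest the usual maximal-marginal-relevance trade-off; since `_MmrCandidate.score` itself is not shown in this diff, the following is an assumed sketch of that scoring, not code from the PR:

def mmr_score(
    similarity_to_query: float,
    max_similarity_to_selected: float,
    lambda_mult: float = 0.5,
) -> float:
    # Standard MMR trade-off: reward query similarity, penalize redundancy
    # with already-selected nodes.
    return (
        lambda_mult * similarity_to_query
        - (1.0 - lambda_mult) * max_similarity_to_selected
    )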
26 changes: 11 additions & 15 deletions packages/graph-retriever/src/graph_retriever/strategies/scored.py
@@ -1,10 +1,10 @@
import dataclasses
import heapq
from collections.abc import Callable, Iterable
from collections.abc import Callable

from typing_extensions import override

from graph_retriever.strategies.base import Strategy
from graph_retriever.strategies.base import NodeTracker, Strategy
from graph_retriever.types import Node


@@ -13,33 +13,29 @@ def __init__(self, score: float, node: Node) -> None:
self.score = score
self.node = node

def __lt__(self, other) -> bool:
def __lt__(self, other: "_ScoredNode") -> bool:
return other.score < self.score


@dataclasses.dataclass
class Scored(Strategy):
"""Strategy selecing nodes using a scoring function."""
"""Strategy selecting nodes using a scoring function."""

scorer: Callable[[Node], float]
_nodes: list[_ScoredNode] = dataclasses.field(default_factory=list)

per_iteration_limit: int | None = None
per_iteration_limit: int = 2

@override
def discover_nodes(self, nodes: dict[str, Node]) -> None:
def iteration(self, nodes: dict[str, Node], tracker: NodeTracker) -> None:
for node in nodes.values():
heapq.heappush(self._nodes, _ScoredNode(self.scorer(node), node))

@override
def select_nodes(self, *, limit: int) -> Iterable[Node]:
if self.per_iteration_limit and self.per_iteration_limit < limit:
limit = self.per_iteration_limit

selected = []
for _x in range(limit):
selected = {}
for _x in range(self.per_iteration_limit):
if not self._nodes:
break

selected.append(heapq.heappop(self._nodes).node)
return selected
node = heapq.heappop(self._nodes).node
selected[node.id] = node
tracker.select_and_traverse(selected)
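A usage sketch for the updated Scored strategy, assuming `Node` exposes a `metadata` mapping (check `graph_retriever.types.Node` for the actual attribute); the `doc_type` key and the weights are hypothetical:

from graph_retriever.strategies.scored import Scored
from graph_retriever.types import Node


def prefer_articles(node: Node) -> float:
    # Hypothetical scorer: favor documents tagged as articles.
    return 2.0 if node.metadata.get("doc_type") == "article" else 1.0


# Each iteration pops the two highest-scoring discovered nodes (the new
# per_iteration_limit default of 2) and both selects and traverses them.
strategy = Scored(scorer=prefer_articles, k=8, per_iteration_limit=2)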