Skip to content

Commit

Permalink
fix(agents-api): MMR limit check
Browse files Browse the repository at this point in the history
  • Loading branch information
Vedantsahai18 committed Mar 2, 2025
1 parent fd34113 commit 5c4ed4e
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 37 deletions.
13 changes: 8 additions & 5 deletions agents-api/agents_api/queries/chat/gather_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ async def gather_messages(
):
message["content"] = message["content"][0]["text"].strip()

# If recall is disabled, return early
if not recall:
return past_messages, []

Expand All @@ -83,11 +84,11 @@ async def gather_messages(
)
recall_options = session.recall_options

# Return early if recall options not configured
if not recall_options:
# If recall is enabled but recall options are not configured, return early
if recall_options is None:
return past_messages, []

# Get messages to search from
# If recall is enabled and recall options are configured, get messages to search from
search_messages = [
msg
for msg in (past_messages + new_raw_messages)[-(recall_options.num_search_messages) :]
Expand Down Expand Up @@ -175,9 +176,11 @@ async def gather_messages(
indices = maximal_marginal_relevance(
np.asarray(query_embedding),
[doc.snippet.embedding for doc in docs_with_embeddings],
k=recall_options.limit,
k=min(recall_options.limit, len(docs_with_embeddings)),
lambda_mult=1 - recall_options.mmr_strength,
)
doc_references = [docs_with_embeddings[i] for i in indices]
doc_references = [
doc for i, doc in enumerate(docs_with_embeddings) if i in set(indices)
]

return past_messages, doc_references
8 changes: 4 additions & 4 deletions agents-api/agents_api/routers/docs/search_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ async def search_user_docs(
indices = maximal_marginal_relevance(
np.asarray(params["embedding"]),
[doc.snippet.embedding for doc in docs_with_embeddings],
k=search_params.limit,
k=min(search_params.limit, len(docs_with_embeddings)),
lambda_mult=1 - search_params.mmr_strength,
)
docs = [docs_with_embeddings[i] for i in indices]
docs = [doc for i, doc in enumerate(docs_with_embeddings) if i in set(indices)]

end = time.time()

Expand Down Expand Up @@ -121,10 +121,10 @@ async def search_agent_docs(
indices = maximal_marginal_relevance(
np.asarray(params["embedding"]),
[doc.snippet.embedding for doc in docs_with_embeddings],
k=search_params.limit,
k=min(search_params.limit, len(docs_with_embeddings)),
lambda_mult=1 - search_params.mmr_strength,
)
docs = [docs_with_embeddings[i] for i in indices]
docs = [doc for i, doc in enumerate(docs_with_embeddings) if i in set(indices)]

end = time.time()

Expand Down
2 changes: 1 addition & 1 deletion agents-api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ dev = [
"pyjwt>=2.10.1",
"pyright>=1.1.391",
"pytype>=2024.10.11",
"ruff>=0.8.4",
"ruff>=0.9.0",
"sqlvalidator>=0.0.20",
"testcontainers[postgres,localstack]>=4.9.0",
"ward>=0.68.0b0",
Expand Down
8 changes: 4 additions & 4 deletions agents-api/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion integrations-service/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dev = [
"pytest-asyncio>=0.24.0",
"pytest-cov>=6.0.0",
"pytype>=2024.10.11",
"ruff>=0.8.1",
"ruff>=0.9.0",
]

[tool.pytest.ini_options]
Expand Down
44 changes: 22 additions & 22 deletions integrations-service/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5c4ed4e

Please sign in to comment.