diff --git a/python/beeai_framework/tools/search/__init__.py b/python/beeai_framework/tools/search/__init__.py index 73bfc656..21e65808 100644 --- a/python/beeai_framework/tools/search/__init__.py +++ b/python/beeai_framework/tools/search/__init__.py @@ -16,5 +16,6 @@ # manually defined import order is import here to avoid circular imports from beeai_framework.tools.search.base import SearchToolResult, SearchToolOutput from beeai_framework.tools.search.duckduckgo import DuckDuckGoSearchTool +from beeai_framework.tools.search.wikipedia import WikipediaTool -__all__ = ["DuckDuckGoSearchTool", "SearchToolOutput", "SearchToolResult"] +__all__ = ["DuckDuckGoSearchTool", "SearchToolOutput", "SearchToolResult", "WikipediaTool"] diff --git a/python/beeai_framework/tools/search/wikipedia.py b/python/beeai_framework/tools/search/wikipedia.py index 47fb1183..1b254eab 100644 --- a/python/beeai_framework/tools/search/wikipedia.py +++ b/python/beeai_framework/tools/search/wikipedia.py @@ -13,16 +13,65 @@ # limitations under the License. 
from typing import Any

import wikipediaapi
from pydantic import BaseModel, Field

# Import from the defining module rather than the package __init__: the package
# __init__ imports this module, so going through it relies on import ordering.
from beeai_framework.tools.search.base import SearchToolOutput, SearchToolResult
from beeai_framework.tools.tool import Tool


class WikipediaToolInput(BaseModel):
    """Input accepted by ``WikipediaTool``.

    ``section_titles`` takes precedence over ``full_text`` when both are set;
    with neither set the page summary is returned.
    """

    query: str = Field(description="Search query, name of the Wikipedia page.")
    full_text: bool = Field(description="If set to true will return the full text of the page.", default=False)
    section_titles: bool = Field(description="If set to true returns section titles as the description.", default=False)
    language: str | None = Field(description="Retrieves specified language version if available.", default=None)


class WikipediaToolResult(SearchToolResult):
    """Single Wikipedia result; adds nothing on top of ``SearchToolResult``."""


class WikipediaToolOutput(SearchToolOutput):
    """Output of ``WikipediaTool``; adds nothing on top of ``SearchToolOutput``."""


class WikipediaTool(Tool[WikipediaToolInput]):
    """Tool that looks up one Wikipedia page by its title.

    The lookup goes through a class-level ``wikipediaapi.Wikipedia`` client
    configured for the English edition; another language edition is used only
    when the English page links to it.
    """

    name = "Wikipedia"
    # Adjacent literals instead of backslash continuations: a backslash inside a
    # string literal keeps the following line's indentation in the value.
    description = (
        "Search factual and historical information, including biography, "
        "history, politics, geography, society, culture, science, technology, people, "
        "animal species, mathematics, and other subjects."
    )
    input_schema = WikipediaToolInput
    client = wikipediaapi.Wikipedia(
        user_agent="beeai-framework https://github.com/i-am-bee/beeai-framework", language="en"
    )

    def get_section_titles(self, sections: list[wikipediaapi.WikipediaPageSection]) -> str:
        """Return the titles of ``sections`` joined by commas.

        Args:
            sections: Sections whose ``title`` attributes are collected, in order.

        Returns:
            A comma-separated string of titles; empty when ``sections`` is empty.
        """
        return ",".join(str(section.title) for section in sections)

    def _run(self, input: WikipediaToolInput, _: Any | None = None) -> WikipediaToolOutput:
        """Fetch the page named by ``input.query`` and wrap it in a single result.

        Args:
            input: Query plus flags selecting summary, full text or section titles,
                and an optional language code.
            _: Unused.

        Returns:
            A ``WikipediaToolOutput`` holding exactly one ``WikipediaToolResult``
            whose title is the original query.

        Raises:
            Exception: If no page exists for ``input.query``.
        """
        page = self.client.page(input.query)
        if not page.exists():
            raise Exception(f"No Wikipedia page matched the search term: {input.query}.")

        # Switch editions only when the page links to the requested language;
        # otherwise keep the page returned by the default client.
        if input.language is not None and input.language in page.langlinks:
            page = page.langlinks[input.language]

        if input.section_titles:
            description_output = self.get_section_titles(page.sections)
        elif input.full_text:
            description_output = page.text
        else:
            description_output = page.summary

        search_results: list[WikipediaToolResult] = [
            WikipediaToolResult(
                title=input.query or "", description=description_output or "", url=page.fullurl or ""
            )
        ]
        return WikipediaToolOutput(search_results)
import asyncio

from beeai_framework.tools.search.wikipedia import (
    WikipediaTool,
    WikipediaToolInput,
)


async def main() -> None:
    """Look up the "bee" page with full text requested and print its text content."""
    wikipedia_client = WikipediaTool()
    # full_text is a field of the tool input, which is what the tool reads when
    # choosing between summary and full text; it is not a constructor option.
    tool_input = WikipediaToolInput(query="bee", full_text=True)
    result = wikipedia_client.run(tool_input)
    print(result.get_text_content())


if __name__ == "__main__":
    asyncio.run(main())
# Copyright 2025 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest

from beeai_framework.tools import ToolInputValidationError
from beeai_framework.tools.search.wikipedia import (
    WikipediaTool,
    WikipediaToolInput,
    WikipediaToolOutput,
)


@pytest.fixture
def tool() -> WikipediaTool:
    """Provide a fresh ``WikipediaTool`` built with no arguments."""
    wikipedia_tool = WikipediaTool()
    return wikipedia_tool


@pytest.mark.e2e
def test_call_invalid_input_type(tool: WikipediaTool) -> None:
    """A dict without the ``query`` key is rejected with a validation error."""
    bad_payload = {"search": "Bee"}
    with pytest.raises(ToolInputValidationError):
        tool.run(input=bad_payload)


@pytest.mark.e2e
def test_output(tool: WikipediaTool) -> None:
    """The default lookup returns text containing the expected summary sentence."""
    request = WikipediaToolInput(query="bee")
    output = tool.run(input=request)
    assert type(output) is WikipediaToolOutput
    text = output.get_text_content()
    assert "Bees are winged insects closely related to wasps and ants" in text


@pytest.mark.e2e
def test_full_text_output(tool: WikipediaTool) -> None:
    """With ``full_text=True`` the returned text contains the expected marker."""
    request = WikipediaToolInput(query="bee", full_text=True)
    output = tool.run(input=request)
    assert type(output) is WikipediaToolOutput
    text = output.get_text_content()
    assert "n-triscosane" in text


@pytest.mark.e2e
def test_section_titles(tool: WikipediaTool) -> None:
    """With ``section_titles=True`` the returned text lists a known section title."""
    request = WikipediaToolInput(query="bee", section_titles=True)
    output = tool.run(input=request)
    assert type(output) is WikipediaToolOutput
    text = output.get_text_content()
    assert "Characteristics" in text


@pytest.mark.e2e
def test_alternate_language(tool: WikipediaTool) -> None:
    """Requesting ``language="fr"`` returns the expected French text."""
    request = WikipediaToolInput(query="bee", language="fr")
    output = tool.run(input=request)
    assert type(output) is WikipediaToolOutput
    text = output.get_text_content()
    assert "Les abeilles (Anthophila) forment un clade d'insectes" in text