diff --git a/langchain/document_loaders/rss.py b/langchain/document_loaders/rss.py
new file mode 100644
index 0000000000000..eaee57be05b45
--- /dev/null
+++ b/langchain/document_loaders/rss.py
@@ -0,0 +1,126 @@
+"""Loader that fetches an RSS feed and loads its items."""
+from typing import Any, Callable, Generator, List, Optional
+
+from bs4 import BeautifulSoup
+from lxml import etree
+
+from langchain.document_loaders.web_base import WebBaseLoader
+from langchain.schema import Document
+
+
+def _default_parsing_function_text(content: Any) -> str:
+    """Return the item content, falling back to the description."""
+    text = ""
+    if "content" in content:
+        text = content["content"]
+    elif "description" in content:
+        text = content["description"]
+
+    return text
+
+
+def _default_parsing_function_meta(meta: Any) -> dict:
+    """Return the item fields as metadata, dropping the text fields."""
+    r_meta = dict(meta)
+    if "content" in r_meta:
+        del r_meta["content"]
+
+    if "description" in r_meta:
+        del r_meta["description"]
+
+    return r_meta
+
+
+class RssLoader(WebBaseLoader):
+    """Loader that fetches an RSS feed and loads its items."""
+
+    def __init__(
+        self,
+        web_path: str,
+        header_template: Optional[dict] = None,
+        parsing_function_text: Optional[Callable] = None,
+        parsing_function_meta: Optional[Callable] = None,
+    ):
+        """Initialize with the RSS feed URL and optional parsing functions.
+
+        Args:
+            web_path: url of the RSS feed
+            header_template: HTTP headers to use when fetching the feed
+            parsing_function_text: Function to build the page content
+                from a parsed feed item
+            parsing_function_meta: Function to build the metadata
+                from a parsed feed item
+        """
+
+        try:
+            import lxml  # noqa:F401
+        except ImportError:
+            raise ValueError(
+                "lxml package not found, please install it with " "`pip install lxml`"
+            )
+
+        super().__init__(
+            web_path,
+            header_template=header_template,
+        )
+
+        self.parsing_function_text = (
+            parsing_function_text or _default_parsing_function_text
+        )
+        self.parsing_function_meta = (
+            parsing_function_meta or _default_parsing_function_meta
+        )
+
+        self.namespaces = {
+            "content": "http://purl.org/rss/1.0/modules/content/",
+            "dc": "http://purl.org/dc/elements/1.1/",
+        }
+        self.fields = [
+            {"tag": "./link", "field": "source"},
+            {"tag": "./title", "field": "title"},
+            {"tag": "./category", "field": "category", "multi": True},
+            {"tag": "./pubDate", "field": "publication_date"},
+            {"tag": "./dc:creator", "field": "author"},
+            {"tag": "./description", "field": "description", "type": "html"},
+            {"tag": "./content:encoded", "field": "content", "type": "html"},
+        ]
+        self.items_selector = "./channel/item"
+
+    def parse_rss(self, root: Any) -> Generator[dict, None, None]:
+        """Parse rss xml and yield one dict per feed item."""
+
+        for item in root.findall(self.items_selector):
+            meta: dict = {}
+            for field in self.fields:
+                element_list = item.findall(field["tag"], namespaces=self.namespaces)
+                for element in element_list:
+                    text = element.text or ""
+
+                    # HTML fields are reduced to plain text.
+                    if field.get("type") == "html":
+                        soup = BeautifulSoup(text, "html.parser")
+                        text = soup.get_text()
+
+                    # Multi-valued fields (e.g. category) are collected in a list.
+                    if field.get("multi"):
+                        meta.setdefault(field["field"], []).append(text)
+                    else:
+                        meta[field["field"]] = text
+
+            yield meta
+
+    def load(self) -> List[Document]:
+        """Load feeds."""
+
+        docs: List[Document] = []
+        for feed in self.web_paths:
+            response = self.session.get(feed)
+            root = etree.fromstring(response.content)
+
+            for item in self.parse_rss(root):
+                text = self.parsing_function_text(item)
+                metadata = self.parsing_function_meta(item)
+
+                docs.append(Document(page_content=text, metadata=metadata))
+
+        return docs
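For reference, a minimal usage sketch of the RssLoader added above (illustrative, not part of the patch). The feed URL is a placeholder, and the loader is imported from the module path introduced by this diff rather than from the package root:

from langchain.document_loaders.rss import RssLoader

# Placeholder feed URL, shown only to illustrate the call signature.
loader = RssLoader("https://example.com/feed.xml")
docs = loader.load()

# Each Document carries the item body (content:encoded or description) as
# page_content; the remaining feed fields (source, title, category,
# publication_date, author) end up in metadata.
for doc in docs:
    print(doc.metadata.get("title"), len(doc.page_content))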
diff --git a/langchain/document_loaders/sitemap.py b/langchain/document_loaders/sitemap.py
index 3a417dd0b4ac4..7054bd01ae574 100644
--- a/langchain/document_loaders/sitemap.py
+++ b/langchain/document_loaders/sitemap.py
@@ -1,6 +1,10 @@
 """Loader that fetches a sitemap and loads those URLs."""
 import re
-from typing import Any, Callable, List, Optional
+import itertools
+from typing import Any, Callable, Generator, Iterable, List, Optional
+
+from aiohttp.helpers import BasicAuth
+from aiohttp.typedefs import StrOrURL

 from langchain.document_loaders.web_base import WebBaseLoader
 from langchain.schema import Document
@@ -9,6 +13,13 @@
 def _default_parsing_function(content: Any) -> str:
     return str(content.get_text())

+def _default_meta_function(meta: dict, _content: Any) -> dict:
+    return meta
+
+def _batch_block(iterable: Iterable, size: int) -> Generator[List[dict], None, None]:
+    it = iter(iterable)
+    while item := list(itertools.islice(it, size)):
+        yield item

 class SitemapLoader(WebBaseLoader):
     """Loader that fetches a sitemap and loads those URLs."""
@@ -18,6 +29,13 @@ def __init__(
         web_path: str,
         filter_urls: Optional[List[str]] = None,
         parsing_function: Optional[Callable] = None,
+        meta_function: Optional[Callable] = None,
+        header_template: Optional[dict] = None,
+        proxy: Optional[StrOrURL] = None,
+        proxy_auth: Optional[BasicAuth] = None,
+        cookies: Optional[dict] = None,
+        blocksize: Optional[int] = None,
+        blocknum: Optional[int] = None,
     ):
         """Initialize with webpage path and optional filter URLs.

@@ -26,6 +44,10 @@
             filter_urls: list of strings or regexes that will be applied to filter the
                 urls that are parsed and loaded
             parsing_function: Function to parse bs4.Soup output
+            proxy: proxy url
+            proxy_auth: proxy server authentication
+            blocksize: number of sitemap locations per block
+            blocknum: the number of the block that should be loaded - zero indexed
         """

         try:
@@ -35,9 +57,19 @@
                 "lxml package not found, please install it with " "`pip install lxml`"
             )

-        super().__init__(web_path)
+        super().__init__(
+            web_path,
+            proxy=proxy,
+            proxy_auth=proxy_auth,
+            cookies=cookies,
+            header_template=header_template,
+        )
+
+        self.blocksize = blocksize
+        self.blocknum = blocknum

         self.filter_urls = filter_urls
+        self.meta_function = meta_function or _default_meta_function
         self.parsing_function = parsing_function or _default_parsing_function

     def parse_sitemap(self, soup: Any) -> List[dict]:
@@ -76,12 +108,22 @@ def load(self) -> List[Document]:

         els = self.parse_sitemap(soup)

+        if self.blocksize is not None and self.blocknum is not None:
+            elblocks = list(_batch_block(els, self.blocksize))
+            blockcount = len(elblocks)
+            if blockcount - 1 < self.blocknum:
+                raise ValueError(
+                    "Selected sitemap does not contain enough blocks for given blocknum"
+                )
+            else:
+                els = elblocks[self.blocknum]
+
         results = self.scrape_all([el["loc"].strip() for el in els if "loc" in el])

         return [
             Document(
                 page_content=self.parsing_function(results[i]),
-                metadata={**{"source": els[i]["loc"]}, **els[i]},
+                metadata={**{"source": els[i]["loc"]}, **self.meta_function(els[i], results[i])},
             )
             for i in range(len(results))
         ]
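A sketch of how the new SitemapLoader options are meant to be combined (illustrative, not part of the patch). The sitemap URL matches the one used in the integration tests below; the custom meta_function is a made-up example:

from typing import Any

from langchain.document_loaders import SitemapLoader


def keep_loc_and_lastmod(meta: dict, _content: Any) -> dict:
    # meta is one parsed sitemap entry; "lastmod" may be absent, so default it.
    return {"source": meta["loc"], "lastmod": meta.get("lastmod", "")}


# Load only the second block of 10 URLs (blocknum is zero indexed). A
# ValueError is raised if the sitemap has fewer than blocknum + 1 blocks.
loader = SitemapLoader(
    "https://langchain.readthedocs.io/sitemap.xml",
    blocksize=10,
    blocknum=1,
    meta_function=keep_loc_and_lastmod,
)
docs = loader.load()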
diff --git a/langchain/document_loaders/web_base.py b/langchain/document_loaders/web_base.py
index 50cf549db0566..218556107c458 100644
--- a/langchain/document_loaders/web_base.py
+++ b/langchain/document_loaders/web_base.py
@@ -6,6 +6,8 @@

 import aiohttp
 import requests
+from aiohttp.helpers import BasicAuth
+from aiohttp.typedefs import StrOrURL

 from langchain.docstore.document import Document
 from langchain.document_loaders.base import BaseLoader
@@ -47,8 +49,19 @@ class WebBaseLoader(BaseLoader):
     default_parser: str = "html.parser"
     """Default parser to use for BeautifulSoup."""

+    proxy: Optional[StrOrURL] = None
+    """aiohttp proxy server"""
+
+    proxy_auth: Optional[BasicAuth] = None
+    """aiohttp proxy auth"""
+
     def __init__(
-        self, web_path: Union[str, List[str]], header_template: Optional[dict] = None
+        self,
+        web_path: Union[str, List[str]],
+        header_template: Optional[dict] = None,
+        proxy: Optional[StrOrURL] = None,
+        proxy_auth: Optional[BasicAuth] = None,
+        cookies: Optional[dict] = None,
     ):
         """Initialize with webpage path."""

@@ -61,6 +74,8 @@
             self.web_paths = web_path

         self.session = requests.Session()
+        self.proxy = proxy
+        self.proxy_auth = proxy_auth
         try:
             import bs4  # noqa:F401
         except ImportError:
@@ -68,17 +83,28 @@
                 "bs4 package not found, please install it with " "`pip install bs4`"
             )

-        try:
-            from fake_useragent import UserAgent
-
-            headers = header_template or default_header_template
-            headers["User-Agent"] = UserAgent().random
-            self.session.headers = dict(headers)
-        except ImportError:
-            logger.info(
-                "fake_useragent not found, using default user agent."
-                "To get a realistic header for requests, `pip install fake_useragent`."
-            )
+        headers = header_template or default_header_template
+        if (
+            "User-Agent" not in headers
+            or headers["User-Agent"] == ""
+            or headers["User-Agent"] is None
+        ):
+            try:
+                from fake_useragent import UserAgent
+
+                headers["User-Agent"] = UserAgent().random
+            except ImportError:
+                logger.info(
+                    "fake_useragent not found, using default user agent."
+                    "To get a realistic header for requests, `pip install fake_useragent`."
+                )
+
+        self.session.headers = dict(headers)
+
+        # Merge any user-provided cookies into the session
+        if cookies is None:
+            cookies = {}
+        self.session.cookies.update(cookies)

     @property
     def web_path(self) -> str:
@@ -89,11 +115,16 @@ def web_path(self) -> str:
     async def _fetch(
         self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
     ) -> str:
-        async with aiohttp.ClientSession() as session:
+        async with aiohttp.ClientSession(
+            cookies=self.session.cookies.get_dict()
+        ) as session:
             for i in range(retries):
                 try:
                     async with session.get(
-                        url, headers=self.session.headers
+                        url,
+                        headers=self.session.headers,
+                        proxy=self.proxy,
+                        proxy_auth=self.proxy_auth,
                     ) as response:
                         return await response.text()
                 except aiohttp.ClientConnectionError as e:
@@ -168,7 +199,14 @@ def _scrape(self, url: str, parser: Union[str, None] = None) -> Any:

         self._check_parser(parser)

-        html_doc = self.session.get(url)
+        proxies = None
+        if self.proxy is not None:
+            proxies = {
+                "http": self.proxy,
+                "https": self.proxy,
+            }
+
+        html_doc = self.session.get(url, proxies=proxies)
         return BeautifulSoup(html_doc.text, parser)

     def scrape(self, parser: Union[str, None] = None) -> Any:
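The proxy, proxy_auth, and cookies parameters threaded through WebBaseLoader above apply to both the synchronous requests path (_scrape) and the aiohttp path (_fetch). A hedged sketch with placeholder proxy, credential, and cookie values:

from aiohttp import BasicAuth

from langchain.document_loaders.web_base import WebBaseLoader

loader = WebBaseLoader(
    "https://example.com/page.html",
    proxy="http://proxy.example.com:8080",  # placeholder proxy URL
    proxy_auth=BasicAuth("user", "password"),  # placeholder credentials
    cookies={"sessionid": "abc123"},  # merged into the requests session
)
docs = loader.load()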
diff --git a/langchain/vectorstores/pgvector.py b/langchain/vectorstores/pgvector.py
index 27008eb59f423..24fb4b77e2b34 100644
--- a/langchain/vectorstores/pgvector.py
+++ b/langchain/vectorstores/pgvector.py
@@ -4,6 +4,7 @@
 import enum
 import logging
 import uuid
+import os
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

 import sqlalchemy
@@ -19,7 +20,7 @@

 Base = declarative_base()  # type: Any

-ADA_TOKEN_COUNT = 1536
+ADA_TOKEN_COUNT = int(os.getenv("PGVECTOR_ADA_TOKEN_COUNT", default="1536"))

 _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"

diff --git a/tests/integration_tests/document_loaders/test_sitemap.py b/tests/integration_tests/document_loaders/test_sitemap.py
index 87147ec6309b0..cc5d8e4664161 100644
--- a/tests/integration_tests/document_loaders/test_sitemap.py
+++ b/tests/integration_tests/document_loaders/test_sitemap.py
@@ -1,5 +1,6 @@
 from langchain.document_loaders import SitemapLoader
+import pytest


 def test_sitemap() -> None:
     """Test sitemap loader."""
@@ -9,11 +10,34 @@ def test_sitemap() -> None:
     assert "🦜🔗" in documents[0].page_content


+def test_sitemap_block() -> None:
+    """Test sitemap loader with blocksize and blocknum."""
+    loader = SitemapLoader("https://langchain.readthedocs.io/sitemap.xml", blocksize=1, blocknum=1)
+    documents = loader.load()
+    assert len(documents) == 1
+    assert "🦜🔗" in documents[0].page_content
+
+
+def test_sitemap_block_only_one() -> None:
+    """Test sitemap loader with a single block covering the whole sitemap."""
+    loader = SitemapLoader("https://langchain.readthedocs.io/sitemap.xml", blocksize=1000000, blocknum=0)
+    documents = loader.load()
+    assert len(documents) > 1
+    assert "🦜🔗" in documents[0].page_content
+
+
+def test_sitemap_block_does_not_exist() -> None:
+    """Test sitemap loader with a block number past the end of the sitemap."""
+    loader = SitemapLoader("https://langchain.readthedocs.io/sitemap.xml", blocksize=1000000, blocknum=15)
+    with pytest.raises(ValueError):
+        loader.load()
+
+
 def test_filter_sitemap() -> None:
     """Test sitemap loader."""
     loader = SitemapLoader(
         "https://langchain.readthedocs.io/sitemap.xml",
-        filter_urls=["https://langchain.readthedocs.io/en/stable/"],
+        filter_urls=["https://python.langchain.com/en/stable/"],
     )
     documents = loader.load()
     assert len(documents) == 1
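The pgvector change reads the vector dimension from the environment at import time, so PGVECTOR_ADA_TOKEN_COUNT has to be set before langchain.vectorstores.pgvector is first imported. A small sketch; the value 768 is illustrative (e.g. for a non-OpenAI embedding model):

import os

# Must be set before the pgvector module is imported, because
# ADA_TOKEN_COUNT is evaluated once at module import time.
os.environ["PGVECTOR_ADA_TOKEN_COUNT"] = "768"

from langchain.vectorstores.pgvector import PGVector  # noqa: E402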