From 0e546347c00c99f273c6aebc33b8165a1ca3065c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adri=C3=A1n=20Chaves?=
Date: Tue, 20 Aug 2024 17:18:24 +0200
Subject: [PATCH] Add an input URL list parameter (#38)

---
 tests/test_ecommerce.py                    | 69 ++++++++++++++++++++++
 zyte_spider_templates/params.py            | 54 ++++++++++++++++-
 zyte_spider_templates/spiders/base.py      |  4 +-
 zyte_spider_templates/spiders/ecommerce.py |  2 +
 4 files changed, 127 insertions(+), 2 deletions(-)

diff --git a/tests/test_ecommerce.py b/tests/test_ecommerce.py
index ae9dd7c..dfdd264 100644
--- a/tests/test_ecommerce.py
+++ b/tests/test_ecommerce.py
@@ -395,6 +395,22 @@ def test_metadata():
                 "title": "URL",
                 "type": "string",
             },
+            "urls": {
+                "anyOf": [
+                    {"items": {"type": "string"}, "type": "array"},
+                    {"type": "null"},
+                ],
+                "default": None,
+                "description": (
+                    "Initial URLs for the crawl, separated by new lines. Enter the "
+                    "full URL including http(s), you can copy and paste it from your "
+                    "browser. Example: https://toscrape.com/"
+                ),
+                "exclusiveRequired": True,
+                "group": "inputs",
+                "title": "URLs",
+                "widget": "textarea",
+            },
             "urls_file": {
                 "default": "",
                 "description": (
@@ -706,12 +722,24 @@ def test_input_none():
 
 def test_input_multiple():
     crawler = get_crawler()
+    with pytest.raises(ValueError):
+        EcommerceSpider.from_crawler(
+            crawler,
+            url="https://a.example",
+            urls=["https://b.example"],
+        )
     with pytest.raises(ValueError):
         EcommerceSpider.from_crawler(
             crawler,
             url="https://a.example",
             urls_file="https://b.example",
         )
+    with pytest.raises(ValueError):
+        EcommerceSpider.from_crawler(
+            crawler,
+            urls=["https://a.example"],
+            urls_file="https://b.example",
+        )
 
 
 def test_url_invalid():
@@ -720,6 +748,47 @@ def test_url_invalid():
     crawler = get_crawler()
     with pytest.raises(ValueError):
         EcommerceSpider.from_crawler(crawler, url="foo")
 
 
+def test_urls(caplog):
+    crawler = get_crawler()
+    url = "https://example.com"
+
+    spider = EcommerceSpider.from_crawler(crawler, urls=[url])
+    start_requests = list(spider.start_requests())
+    assert len(start_requests) == 1
+    assert start_requests[0].url == url
+    assert start_requests[0].callback == spider.parse_navigation
+
+    spider = EcommerceSpider.from_crawler(crawler, urls=url)
+    start_requests = list(spider.start_requests())
+    assert len(start_requests) == 1
+    assert start_requests[0].url == url
+    assert start_requests[0].callback == spider.parse_navigation
+
+    caplog.clear()
+    spider = EcommerceSpider.from_crawler(
+        crawler,
+        urls="https://a.example\n \nhttps://b.example\nhttps://c.example\nfoo\n\n",
+    )
+    assert "'foo', from the 'urls' spider argument, is not a valid URL" in caplog.text
+    start_requests = list(spider.start_requests())
+    assert len(start_requests) == 3
+    assert all(
+        request.callback == spider.parse_navigation for request in start_requests
+    )
+    assert start_requests[0].url == "https://a.example"
+    assert start_requests[1].url == "https://b.example"
+    assert start_requests[2].url == "https://c.example"
+
+    caplog.clear()
+    with pytest.raises(ValueError):
+        spider = EcommerceSpider.from_crawler(
+            crawler,
+            urls="foo\nbar",
+        )
+    assert "'foo', from the 'urls' spider argument, is not a valid URL" in caplog.text
+    assert "'bar', from the 'urls' spider argument, is not a valid URL" in caplog.text
+
+
 def test_urls_file():
     crawler = get_crawler()
     url = "https://example.com"
diff --git a/zyte_spider_templates/params.py b/zyte_spider_templates/params.py
index f5c246f..d9245a8 100644
--- a/zyte_spider_templates/params.py
+++ b/zyte_spider_templates/params.py
@@ -1,6 +1,8 @@
 import json
+import re
 from enum import Enum
-from typing import Dict, Optional, Union
+from logging import getLogger
+from typing import Dict, List, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator
 
@@ -12,6 +14,8 @@
 from .utils import _URL_PATTERN
 
+logger = getLogger(__name__)
+
 
 @document_enum
 class ExtractFrom(str, Enum):
@@ -110,6 +114,54 @@ class UrlParam(BaseModel):
     )
 
 
+class UrlsParam(BaseModel):
+    urls: Optional[List[str]] = Field(
+        title="URLs",
+        description=(
+            "Initial URLs for the crawl, separated by new lines. Enter the "
+            "full URL including http(s), you can copy and paste it from your "
+            "browser. Example: https://toscrape.com/"
+        ),
+        default=None,
+        json_schema_extra={
+            "group": "inputs",
+            "exclusiveRequired": True,
+            "widget": "textarea",
+        },
+    )
+
+    @field_validator("urls", mode="before")
+    @classmethod
+    def validate_url_list(cls, value: Union[List[str], str]) -> List[str]:
+        """Validate a list of URLs.
+
+        If a string is received as input, it is split into multiple strings
+        on new lines.
+
+        List items that do not match a URL pattern trigger a warning and are
+        removed from the list. If all URLs are invalid, validation fails.
+        """
+        if isinstance(value, str):
+            value = value.split("\n")
+        if not value:
+            return value
+        result = []
+        for v in value:
+            v = v.strip()
+            if not v:
+                continue
+            if not re.search(_URL_PATTERN, v):
+                logger.warning(
+                    f"{v!r}, from the 'urls' spider argument, is not a "
+                    f"valid URL and will be ignored."
+                )
+                continue
+            result.append(v)
+        if not result:
+            raise ValueError(f"No valid URL found in {value!r}")
+        return result
+
+
 class PostalAddress(BaseModel):
     """
     Represents a postal address with various optional components such as
diff --git a/zyte_spider_templates/spiders/base.py b/zyte_spider_templates/spiders/base.py
index bb9cf29..846b87a 100644
--- a/zyte_spider_templates/spiders/base.py
+++ b/zyte_spider_templates/spiders/base.py
@@ -11,12 +11,13 @@
     MaxRequestsParam,
     UrlParam,
     UrlsFileParam,
+    UrlsParam,
 )
 
 # Higher priority than command-line-defined settings (40).
 ARG_SETTING_PRIORITY: int = 50
 
-_INPUT_FIELDS = ("url", "urls_file")
+_INPUT_FIELDS = ("url", "urls", "urls_file")
 
 
 class BaseSpiderParams(
@@ -24,6 +25,7 @@ class BaseSpiderParams(
     MaxRequestsParam,
     GeolocationParam,
     UrlsFileParam,
+    UrlsParam,
     UrlParam,
     BaseModel,
 ):
diff --git a/zyte_spider_templates/spiders/ecommerce.py b/zyte_spider_templates/spiders/ecommerce.py
index 3882a8c..f8242ce 100644
--- a/zyte_spider_templates/spiders/ecommerce.py
+++ b/zyte_spider_templates/spiders/ecommerce.py
@@ -126,6 +126,8 @@ def _init_input(self):
             urls = load_url_list(response.text)
             self.logger.info(f"Loaded {len(urls)} initial URLs from {urls_file}.")
             self.start_urls = urls
+        elif self.args.urls:
+            self.start_urls = self.args.urls
         else:
             self.start_urls = [self.args.url]
         self.allowed_domains = list(set(get_domain(url) for url in self.start_urls))
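
Usage note (not part of the patch): a minimal sketch of how the new `urls`
argument behaves once this change is applied, exercising `UrlsParam` directly.
The import path matches the patched module; the snippet itself is illustrative
only.

    # Illustrative only; mirrors the UrlsParam validator added above.
    from zyte_spider_templates.params import UrlsParam

    # A newline-separated string is split; blank entries are dropped and
    # invalid entries (here "foo") are logged with a warning and ignored.
    params = UrlsParam(urls="https://a.example\n\nhttps://b.example\nfoo")
    print(params.urls)  # ['https://a.example', 'https://b.example']

    # If no entry is a valid URL, validation fails (pydantic wraps the
    # validator's ValueError in a ValidationError).
    UrlsParam(urls="foo\nbar")  # raises pydantic.ValidationError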