From 5797451d00949ea798c3603ee3119ff3b185f173 Mon Sep 17 00:00:00 2001 From: Muspi Merol Date: Sun, 26 May 2024 14:49:24 +0800 Subject: [PATCH] Abort other fetches when resolution fails --- micropip/_compat_in_pyodide.py | 28 +++++++++++++++++++++++++--- micropip/transaction.py | 12 +++++++++++- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/micropip/_compat_in_pyodide.py b/micropip/_compat_in_pyodide.py index e47aadf..93c18ea 100644 --- a/micropip/_compat_in_pyodide.py +++ b/micropip/_compat_in_pyodide.py @@ -1,3 +1,4 @@ +from asyncio import CancelledError from pathlib import Path from urllib.parse import urlparse @@ -7,7 +8,7 @@ try: import pyodide_js - from js import Object + from js import AbortController, Object from pyodide_js import loadedPackages, loadPackage from pyodide_js._api import ( # type: ignore[import] loadBinaryFile, @@ -21,6 +22,27 @@ raise # Otherwise, this is pytest test collection so let it go. +if IN_BROWSER: + + async def _pyfetch(urls: str, **kwargs): + if "signal" in kwargs: + return await pyfetch(urls, **kwargs) + + controler = AbortController.new() + kwargs["signal"] = controler.signal + + async def fetch_with_abort(): + try: + return await pyfetch(urls, **kwargs) + except CancelledError: + controler.abort() + raise + + return await fetch_with_abort() + +else: + _pyfetch = pyfetch + async def fetch_bytes(url: str, kwargs: dict[str, str]) -> bytes: parsed_url = urlparse(url) @@ -29,13 +51,13 @@ async def fetch_bytes(url: str, kwargs: dict[str, str]) -> bytes: if parsed_url.scheme == "file": return (await loadBinaryFile(parsed_url.path)).to_bytes() - return await (await pyfetch(url, **kwargs)).bytes() + return await (await _pyfetch(url, **kwargs)).bytes() async def fetch_string_and_headers( url: str, kwargs: dict[str, str] ) -> tuple[str, dict[str, str]]: - response = await pyfetch(url, **kwargs) + response = await _pyfetch(url, **kwargs) content = await response.string() # TODO: replace with response.headers when pyodide>= 0.24 is released diff --git a/micropip/transaction.py b/micropip/transaction.py index bbc91e2..6686157 100644 --- a/micropip/transaction.py +++ b/micropip/transaction.py @@ -52,7 +52,17 @@ async def gather_requirements( for requirement in requirements: requirement_promises.append(self.add_requirement(requirement)) - await asyncio.gather(*requirement_promises) + futures: list[asyncio.Future] = [] + try: + for coro in requirement_promises: + futures.append(asyncio.ensure_future(coro)) + await asyncio.gather(*futures) + except ValueError: + if not self.keep_going: + for future in futures: + if not future.done(): + future.cancel() + raise async def add_requirement(self, req: str | Requirement) -> None: if isinstance(req, Requirement):