diff --git a/pyproject.toml b/pyproject.toml index 33abbd5d9..0c66e2f50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dev = [ "aiohttp", "pytest", "pytest-cov", + "pytest-mock", "pytest-asyncio", "coverage", "jinja2", diff --git a/test_elasticsearch/test_dsl/conftest.py b/test_elasticsearch/test_dsl/conftest.py index 2e5fa91af..653b54588 100644 --- a/test_elasticsearch/test_dsl/conftest.py +++ b/test_elasticsearch/test_dsl/conftest.py @@ -22,7 +22,7 @@ import time from datetime import datetime from typing import Any, AsyncGenerator, Dict, Generator, Tuple, cast -from unittest import SkipTest, TestCase +from unittest import SkipTest from unittest.mock import AsyncMock, Mock import pytest_asyncio @@ -46,22 +46,21 @@ create_flat_git_index, create_git_index, ) +from ..utils import CA_CERTS -if "ELASTICSEARCH_URL" in os.environ: - ELASTICSEARCH_URL = os.environ["ELASTICSEARCH_URL"] -else: - ELASTICSEARCH_URL = "http://localhost:9200" - -def get_test_client(wait: bool = True, **kwargs: Any) -> Elasticsearch: +def get_test_client(elasticsearch_url, wait: bool = True, **kwargs: Any) -> Elasticsearch: # construct kwargs from the environment kw: Dict[str, Any] = {"request_timeout": 30} + if elasticsearch_url.startswith("https://"): + kw["ca_certs"] = CA_CERTS + if "PYTHON_CONNECTION_CLASS" in os.environ: kw["node_class"] = os.environ["PYTHON_CONNECTION_CLASS"] kw.update(kwargs) - client = Elasticsearch(ELASTICSEARCH_URL, **kw) + client = Elasticsearch(elasticsearch_url, **kw) # wait for yellow status for tries_left in range(100 if wait else 1, 0, -1): @@ -76,15 +75,17 @@ def get_test_client(wait: bool = True, **kwargs: Any) -> Elasticsearch: raise SkipTest("Elasticsearch failed to start.") -async def get_async_test_client(wait: bool = True, **kwargs: Any) -> AsyncElasticsearch: +async def get_async_test_client( + elasticsearch_url, wait: bool = True, **kwargs: Any +) -> AsyncElasticsearch: # construct kwargs from the environment kw: Dict[str, Any] = {"request_timeout": 30} - if "PYTHON_CONNECTION_CLASS" in os.environ: - kw["node_class"] = os.environ["PYTHON_CONNECTION_CLASS"] + if elasticsearch_url.startswith("https://"): + kw["ca_certs"] = CA_CERTS kw.update(kwargs) - client = AsyncElasticsearch(ELASTICSEARCH_URL, **kw) + client = AsyncElasticsearch(elasticsearch_url, **kw) # wait for yellow status for tries_left in range(100 if wait else 1, 0, -1): @@ -100,36 +101,6 @@ async def get_async_test_client(wait: bool = True, **kwargs: Any) -> AsyncElasti raise SkipTest("Elasticsearch failed to start.") -class ElasticsearchTestCase(TestCase): - client: Elasticsearch - - @staticmethod - def _get_client() -> Elasticsearch: - return get_test_client() - - @classmethod - def setup_class(cls) -> None: - cls.client = cls._get_client() - - def teardown_method(self, _: Any) -> None: - # Hidden indices expanded in wildcards in ES 7.7 - expand_wildcards = ["open", "closed"] - if self.es_version() >= (7, 7): - expand_wildcards.append("hidden") - - self.client.indices.delete_data_stream( - name="*", expand_wildcards=expand_wildcards - ) - self.client.indices.delete(index="*", expand_wildcards=expand_wildcards) - self.client.indices.delete_template(name="*") - self.client.indices.delete_index_template(name="*") - - def es_version(self) -> Tuple[int, ...]: - if not hasattr(self, "_es_version"): - self._es_version = _get_version(self.client.info()["version"]["number"]) - return self._es_version - - def _get_version(version_string: str) -> Tuple[int, ...]: if "." not in version_string: return () @@ -138,9 +109,11 @@ def _get_version(version_string: str) -> Tuple[int, ...]: @fixture(scope="session") -def client() -> Elasticsearch: +def client(elasticsearch_url) -> Elasticsearch: try: - connection = get_test_client(wait="WAIT_FOR_ES" in os.environ) + connection = get_test_client( + elasticsearch_url, wait="WAIT_FOR_ES" in os.environ + ) add_connection("default", connection) return connection except SkipTest: @@ -148,9 +121,11 @@ def client() -> Elasticsearch: @pytest_asyncio.fixture -async def async_client() -> AsyncGenerator[AsyncElasticsearch, None]: +async def async_client(elasticsearch_url) -> AsyncGenerator[AsyncElasticsearch, None]: try: - connection = await get_async_test_client(wait="WAIT_FOR_ES" in os.environ) + connection = await get_async_test_client( + elasticsearch_url, wait="WAIT_FOR_ES" in os.environ + ) add_async_connection("default", connection) yield connection await connection.close()