diff --git a/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py b/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py index 55359f7880..15715c2ab4 100644 --- a/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py +++ b/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py @@ -75,28 +75,29 @@ class ElasticsearchOnlineStoreConfig(FeastConfigBaseModel): """ The number of rows to write in a single batch """ -class ElasticsearchConnectionManager: - def __init__(self, online_config: RepoConfig): - self.online_config = online_config - - def __enter__(self): - # Connecting to Elasticsearch - logger.info( - f"Connecting to Elasticsearch with endpoint {self.online_config.endpoint}" - ) - self.client = Elasticsearch( - self.online_config.endpoint, - basic_auth=(self.online_config.username, self.online_config.password), - ) - return self.client +class ElasticsearchOnlineStore(OnlineStore): + _client: Optional[Elasticsearch] = None + + def _get_client(self, config: RepoConfig) -> Elasticsearch: + online_store_config = config.online_store + assert isinstance(online_store_config, ElasticsearchOnlineStoreConfig) - def __exit__(self, exc_type, exc_value, traceback): - # Disconnecting from Elasticsearch - logger.info("Closing the connection to Elasticsearch") - self.client.transport.close() + user = online_store_config.username if online_store_config.username is not None else "" + password = ( + online_store_config.password + if online_store_config.password is not None + else "" + ) + if self._client: + return self._client + else: + self._client = Elasticsearch( + hosts=online_store_config.endpoint, + basic_auth=(user, password), + ) + return self._client -class ElasticsearchOnlineStore(OnlineStore): def _get_bulk_documents(self, index_name, data): for entity_key, values, timestamp, created_ts in data: id_val = self._get_value_from_value_proto(entity_key.entity_values[0]) @@ -114,7 +115,7 @@ def online_write_batch( ], progress: Optional[Callable[[int], Any]], ) -> None: - with ElasticsearchConnectionManager(config.online_store) as es: + with self._get_client(config) as es: resp = es.indices.exists(index=table.name) if not resp.body: self._create_index(es, table) @@ -143,7 +144,7 @@ def online_read( entity_keys: List[EntityKeyProto], requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - with ElasticsearchConnectionManager(config.online_store) as es: + with self._get_client(config) as es: id_list = [] for entity in entity_keys: for val in entity.entity_values: @@ -196,7 +197,7 @@ def update( entities_to_keep: Sequence[Entity], partial: bool, ): - with ElasticsearchConnectionManager(config.online_store) as es: + with self._get_client(config.online_store) as es: for fv in tables_to_delete: resp = es.indices.exists(index=fv.name) if resp.body: diff --git a/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py b/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py index 331c01fee6..d6f81d7e80 100644 --- a/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py +++ b/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py @@ -5,10 +5,10 @@ import pytest +from elasticsearch import Elasticsearch from feast import FeatureView from feast.entity import Entity from feast.expediagroup.vectordb.elasticsearch_online_store import ( - ElasticsearchConnectionManager, ElasticsearchOnlineStore, ElasticsearchOnlineStoreConfig, ) @@ -48,6 +48,21 @@ ] +class ElasticsearchConnectionManager: + def __init__(self, online_config: RepoConfig): + self.online_config = online_config + def __enter__(self): + # Connecting to Elasticsearch + self.client = Elasticsearch( + self.online_config.endpoint, + basic_auth=(self.online_config.username, self.online_config.password), + ) + return self.client + def __exit__(self, exc_type, exc_value, traceback): + # Disconnecting from Elasticsearch + self.client.transport.close() + + @pytest.fixture(scope="session") def repo_config(embedded_elasticsearch): return RepoConfig(