Skip to content

Commit

Permalink
fix: use a persistent client based on opensource implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
piket committed Jul 29, 2024
1 parent 648a237 commit ee5575d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,29 @@ class ElasticsearchOnlineStoreConfig(FeastConfigBaseModel):
""" The number of rows to write in a single batch """


class ElasticsearchConnectionManager:
def __init__(self, online_config: RepoConfig):
self.online_config = online_config

def __enter__(self):
# Connecting to Elasticsearch
logger.info(
f"Connecting to Elasticsearch with endpoint {self.online_config.endpoint}"
)
self.client = Elasticsearch(
self.online_config.endpoint,
basic_auth=(self.online_config.username, self.online_config.password),
)
return self.client
class ElasticsearchOnlineStore(OnlineStore):
_client: Optional[Elasticsearch] = None

def _get_client(self, config: RepoConfig) -> Elasticsearch:
    """Return this store's Elasticsearch client, creating it on first use.

    The client is built once from the online-store settings and cached on
    ``self._client``; every later call returns that same instance without
    looking at ``config`` again.

    Args:
        config: The repo configuration whose ``online_store`` section is an
            ``ElasticsearchOnlineStoreConfig``. The online-store config
            itself is accepted as well, because ``update`` calls this
            method with ``config.online_store``.

    Returns:
        The cached ``Elasticsearch`` client.

    Raises:
        AssertionError: If the resolved online-store config is not an
            ``ElasticsearchOnlineStoreConfig``.
    """
    # Serve the cached client first so repeated calls skip all config work.
    if self._client is not None:
        return self._client

    # Tolerate both the full repo config and its online-store section.
    online_store_config = (
        config
        if isinstance(config, ElasticsearchOnlineStoreConfig)
        else config.online_store
    )
    assert isinstance(online_store_config, ElasticsearchOnlineStoreConfig)

    # Unset credentials are passed as empty strings instead of None.
    user = (
        online_store_config.username
        if online_store_config.username is not None
        else ""
    )
    password = (
        online_store_config.password
        if online_store_config.password is not None
        else ""
    )

    # NOTE(review): callers use the returned client in a ``with`` block. If
    # leaving that block closes the client (elasticsearch-py's ``__exit__``
    # appears to call ``close()``), the cached instance is closed after its
    # first use -- confirm, and call the client directly if so.
    self._client = Elasticsearch(
        hosts=online_store_config.endpoint,
        basic_auth=(user, password),
    )
    return self._client

class ElasticsearchOnlineStore(OnlineStore):
def _get_bulk_documents(self, index_name, data):
for entity_key, values, timestamp, created_ts in data:
id_val = self._get_value_from_value_proto(entity_key.entity_values[0])
Expand All @@ -114,7 +115,7 @@ def online_write_batch(
],
progress: Optional[Callable[[int], Any]],
) -> None:
with ElasticsearchConnectionManager(config.online_store) as es:
with self._get_client(config) as es:
resp = es.indices.exists(index=table.name)
if not resp.body:
self._create_index(es, table)
Expand Down Expand Up @@ -143,7 +144,7 @@ def online_read(
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
with ElasticsearchConnectionManager(config.online_store) as es:
with self._get_client(config) as es:
id_list = []
for entity in entity_keys:
for val in entity.entity_values:
Expand Down Expand Up @@ -196,7 +197,7 @@ def update(
entities_to_keep: Sequence[Entity],
partial: bool,
):
with ElasticsearchConnectionManager(config.online_store) as es:
with self._get_client(config.online_store) as es:
for fv in tables_to_delete:
resp = es.indices.exists(index=fv.name)
if resp.body:
Expand Down
17 changes: 16 additions & 1 deletion sdk/python/tests/expediagroup/test_elasticsearch_online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

import pytest

from elasticsearch import Elasticsearch
from feast import FeatureView
from feast.entity import Entity
from feast.expediagroup.vectordb.elasticsearch_online_store import (
ElasticsearchConnectionManager,
ElasticsearchOnlineStore,
ElasticsearchOnlineStoreConfig,
)
Expand Down Expand Up @@ -48,6 +48,21 @@
]


class ElasticsearchConnectionManager:
    """Context manager that opens an Elasticsearch client and closes it on exit.

    Test helper: ``__enter__`` connects using the online-store settings and
    returns the client; ``__exit__`` closes the client's transport.
    """

    def __init__(self, online_config: "ElasticsearchOnlineStoreConfig"):
        """Store the online-store config to connect with.

        Args:
            online_config: The online-store section of the repo config; its
                ``endpoint``, ``username`` and ``password`` are read when the
                context is entered.
        """
        self.online_config = online_config

    def __enter__(self):
        """Create the client, keep it on ``self.client`` and return it."""
        username = self.online_config.username
        password = self.online_config.password
        # Unset credentials are sent as empty strings, the same substitution
        # the online store makes when it builds its own client.
        self.client = Elasticsearch(
            self.online_config.endpoint,
            basic_auth=(
                username if username is not None else "",
                password if password is not None else "",
            ),
        )
        return self.client

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the client's transport; returns None so exceptions propagate."""
        self.client.transport.close()


@pytest.fixture(scope="session")
def repo_config(embedded_elasticsearch):
return RepoConfig(
Expand Down

0 comments on commit ee5575d

Please sign in to comment.