fix: use persistent es client and batch bulk requests #122

Open · wants to merge 2 commits into master
sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py

@@ -71,29 +71,41 @@ class ElasticsearchOnlineStoreConfig(FeastConfigBaseModel):
     password: str
     """ password to connect to Elasticsearch """
 
+    write_batch_size: Optional[int] = 40
+    """ The number of rows to write in a single batch """
 
-class ElasticsearchConnectionManager:
-    def __init__(self, online_config: RepoConfig):
-        self.online_config = online_config
-
-    def __enter__(self):
-        # Connecting to Elasticsearch
-        logger.info(
-            f"Connecting to Elasticsearch with endpoint {self.online_config.endpoint}"
-        )
-        self.client = Elasticsearch(
-            self.online_config.endpoint,
-            basic_auth=(self.online_config.username, self.online_config.password),
-        )
-        return self.client
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        # Disconnecting from Elasticsearch
-        logger.info("Closing the connection to Elasticsearch")
-        self.client.transport.close()
-
 
 class ElasticsearchOnlineStore(OnlineStore):
+    _client: Optional[Elasticsearch] = None
+
+    def _get_client(self, config: RepoConfig) -> Elasticsearch:
+        online_store_config = config.online_store
+        assert isinstance(online_store_config, ElasticsearchOnlineStoreConfig)
+
+        user = online_store_config.username if online_store_config.username is not None else ""
Collaborator: can the default values be set in ElasticsearchOnlineStoreConfig instead?

Collaborator (author): This is for null safety checks. The config validation is supposed to enforce those values to not be None.
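For illustration only, a minimal sketch of the reviewer's suggestion, assuming pydantic-style validation (FeastConfigBaseModel is pydantic-based); the class and validator names here are hypothetical and not part of this PR:

    from pydantic import BaseModel, validator

    class ElasticsearchConfigSketch(BaseModel):  # hypothetical stand-in for the Feast config
        endpoint: str
        username: str = ""  # defaults here would replace the None checks in _get_client
        password: str = ""
        write_batch_size: int = 40

        @validator("username", "password", pre=True)
        def none_to_empty(cls, v):
            # An explicit `username: null` in YAML still arrives as None; coerce it here.
            return "" if v is None else v
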
+        password = (
+            online_store_config.password
+            if online_store_config.password is not None
+            else ""
+        )
+
+        if self._client:
+            return self._client
+        else:
+            self._client = Elasticsearch(
+                hosts=online_store_config.endpoint,
+                basic_auth=(user, password),
+            )
+            return self._client
+
+    def _get_bulk_documents(self, index_name, data):
+        for entity_key, values, timestamp, created_ts in data:
+            id_val = self._get_value_from_value_proto(entity_key.entity_values[0])
+            document = {entity_key.join_keys[0]: id_val}
+            for feature_name, val in values.items():
+                document[feature_name] = self._get_value_from_value_proto(val)
+            yield {"_index": index_name, "_id": id_val, "_source": document}
 
     def online_write_batch(
         self,
         config: RepoConfig,
@@ -103,24 +115,27 @@ def online_write_batch(
         ],
         progress: Optional[Callable[[int], Any]],
     ) -> None:
-        with ElasticsearchConnectionManager(config.online_store) as es:
+        with self._get_client(config) as es:
             resp = es.indices.exists(index=table.name)
             if not resp.body:
                 self._create_index(es, table)
-            bulk_documents = []
-            for entity_key, values, timestamp, created_ts in data:
-                id_val = self._get_value_from_value_proto(entity_key.entity_values[0])
-                document = {entity_key.join_keys[0]: id_val}
-                for feature_name, val in values.items():
-                    document[feature_name] = self._get_value_from_value_proto(val)
-                bulk_documents.append(
-                    {"_index": table.name, "_id": id_val, "_source": document}
-                )
-
-            successes, errors = helpers.bulk(client=es, actions=bulk_documents)
+
+            successes = 0
+            errors: List[Any] = []
+            error_count = 0
+            for i in range(0, len(data), config.online_store.write_batch_size):
+                batch = data[i : i + config.online_store.write_batch_size]
+                count, errs = helpers.bulk(
+                    client=es, actions=self._get_bulk_documents(table.name, batch)
+                )
+                successes += count
+                if type(errs) is int:
+                    error_count += errs
+                elif type(errs) is list:
+                    errors.extend(errs)
             logger.info(f"bulk write completed with {successes} successes")
+            if error_count:
+                logger.error(f"bulk write encountered {error_count} errors")
             if errors:
-                logger.error(f"bulk write return errors: {errors}")
+                logger.error(f"bulk write returned errors: {errors}")
 
     def online_read(
         self,
@@ -129,7 +144,7 @@ def online_read(
         entity_keys: List[EntityKeyProto],
         requested_features: Optional[List[str]] = None,
     ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
-        with ElasticsearchConnectionManager(config.online_store) as es:
+        with self._get_client(config) as es:
             id_list = []
             for entity in entity_keys:
                 for val in entity.entity_values:
@@ -182,7 +197,7 @@ def update(
         entities_to_keep: Sequence[Entity],
         partial: bool,
     ):
-        with ElasticsearchConnectionManager(config.online_store) as es:
+        with self._get_client(config) as es:
             for fv in tables_to_delete:
                 resp = es.indices.exists(index=fv.name)
                 if resp.body:
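Background on the batching loop above: helpers.bulk returns a (success_count, errors) tuple, where errors is a list of per-document error dicts by default and an int count only when stats_only=True, which is why the new code checks both shapes. Below is a standalone sketch under assumed names (local cluster, demo index); note that helpers.bulk raises BulkIndexError on failures by default, so the sketch passes raise_on_error=False to collect them instead:

    from elasticsearch import Elasticsearch, helpers

    es = Elasticsearch("http://localhost:9200")  # assumed local cluster

    docs = [
        {"_index": "demo", "_id": str(i), "_source": {"value": i}} for i in range(100)
    ]

    batch_size = 40
    successes, errors = 0, []
    for start in range(0, len(docs), batch_size):
        ok, errs = helpers.bulk(
            client=es,
            actions=docs[start : start + batch_size],
            raise_on_error=False,  # collect per-document errors instead of raising
        )
        successes += ok
        # errs is an int only when stats_only=True; here it is a list of dicts
        errors.extend(errs if isinstance(errs, list) else [])
    print(f"indexed {successes} docs with {len(errors)} errors")
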
sdk/python/tests/expediagroup/test_elasticsearch_online_store.py (18 changes: 17 additions & 1 deletion)

@@ -5,10 +5,10 @@
 
 import pytest
 
+from elasticsearch import Elasticsearch
 from feast import FeatureView
 from feast.entity import Entity
 from feast.expediagroup.vectordb.elasticsearch_online_store import (
-    ElasticsearchConnectionManager,
     ElasticsearchOnlineStore,
     ElasticsearchOnlineStoreConfig,
 )
@@ -48,6 +48,21 @@
 ]
 
 
+class ElasticsearchConnectionManager:
+    def __init__(self, online_config: RepoConfig):
+        self.online_config = online_config
+
+    def __enter__(self):
+        # Connecting to Elasticsearch
+        self.client = Elasticsearch(
+            self.online_config.endpoint,
+            basic_auth=(self.online_config.username, self.online_config.password),
+        )
+        return self.client
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Disconnecting from Elasticsearch
+        self.client.transport.close()
+
+
 @pytest.fixture(scope="session")
 def repo_config(embedded_elasticsearch):
     return RepoConfig(
@@ -58,6 +73,7 @@ def repo_config(embedded_elasticsearch):
endpoint=f"http://{embedded_elasticsearch['host']}:{embedded_elasticsearch['port']}",
username=embedded_elasticsearch["username"],
password=embedded_elasticsearch["password"],
write_batch_size=5
),
offline_store=DaskOfflineStoreConfig(),
entity_key_serialization_version=2,
Expand Down
Loading
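For completeness, a hypothetical usage of the now test-local ElasticsearchConnectionManager to verify what online_write_batch persisted; the index name and assertion are illustrative only:

    def test_writes_are_visible(repo_config):
        # assumes online_write_batch(...) was already called for an index named "books"
        with ElasticsearchConnectionManager(repo_config.online_store) as es:
            es.indices.refresh(index="books")  # make freshly written docs searchable
            assert es.count(index="books")["count"] > 0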