Speed up bulk processor flush and cut release #187

Merged · 8 commits · Dec 14, 2023
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "xata"
version = "1.2.1"
version = "1.2.2"
description = "Python SDK for Xata.io"
authors = ["Xata <[email protected]>"]
license = "Apache-2.0"
94 changes: 65 additions & 29 deletions tests/integration-tests/helpers_bulkprocessor_test.py
@@ -30,36 +30,44 @@
class TestHelpersBulkProcessor(object):
def setup_class(self):
self.db_name = utils.get_db_name()
self.branch_name = "main"
self.client = XataClient(db_name=self.db_name, branch_name=self.branch_name)
self.client = XataClient(db_name=self.db_name)
self.fake = Faker()

assert self.client.databases().create(self.db_name).is_success()
assert self.client.table().create("Posts").is_success()
assert self.client.table().create("Users").is_success()

# create schema
assert self.client.table().set_schema(
"Posts",
{
"columns": [
{"name": "title", "type": "string"},
{"name": "text", "type": "text"},
]
},
).is_success()
assert self.client.table().set_schema(
"Users",
{
"columns": [
{"name": "username", "type": "string"},
{"name": "email", "type": "string"},
]
},
).is_success()
assert (
self.client.table()
.set_schema(
"Posts",
{
"columns": [
{"name": "title", "type": "string"},
{"name": "text", "type": "text"},
]
},
)
.is_success()
)
assert (
self.client.table()
.set_schema(
"Users",
{
"columns": [
{"name": "username", "type": "string"},
{"name": "email", "type": "string"},
]
},
)
.is_success()
)

def teardown_class(self):
assert self.client.databases().delete(self.db_name).is_success()
# assert self.client.databases().delete(self.db_name).is_success()
pass

@pytest.fixture
def record(self) -> dict:
@@ -70,7 +78,7 @@ def _get_record(self) -> dict:
"title": self.fake.company(),
"text": self.fake.text(),
}

def _get_user(self) -> dict:
return {
"username": self.fake.name(),
@@ -81,20 +89,22 @@ def test_bulk_insert_records(self, record: dict):
bp = BulkProcessor(
self.client,
thread_pool_size=1,
batch_size=43,
)
bp.put_records("Posts", [self._get_record() for x in range(42)])
bp.flush_queue()

r = self.client.data().summarize("Posts", {"summaries": {"proof": {"count": "*"}}})
assert r.is_success()
assert "summaries" in r
assert r["summaries"][0]["proof"] == 42

stats = bp.get_stats()
assert stats["total"] == 42
assert stats["queue"] == 0
assert stats["failed_batches"] == 0
assert stats["tables"]["Posts"] == 42
assert stats["total_batches"] == 1

r = self.client.data().summarize("Posts", {"summaries": {"proof": {"count": "*"}}})
assert r.is_success()
assert "summaries" in r
assert r["summaries"][0]["proof"] == stats["total"]

def test_flush_queue(self):
assert self.client.sql().query('DELETE FROM "Posts" WHERE 1 = 1').is_success()
@@ -112,15 +122,40 @@ def test_flush_queue(self):
assert r.is_success()
assert "summaries" in r
assert r["summaries"][0]["proof"] == 1000

stats = bp.get_stats()
assert stats["total"] == 1000
assert stats["queue"] == 0
assert stats["failed_batches"] == 0
assert stats["total_batches"] == 20
assert stats["tables"]["Posts"] == 1000

def test_flush_queue_many_threads(self):
assert self.client.sql().query('DELETE FROM "Users" WHERE 1 = 1').is_success()

bp = BulkProcessor(
self.client,
thread_pool_size=8,
batch_size=10,
)
bp.put_records("Users", [self._get_user() for x in range(750)])
bp.flush_queue()

r = self.client.data().summarize("Users", {"summaries": {"proof": {"count": "*"}}})
assert r.is_success()
assert "summaries" in r
assert r["summaries"][0]["proof"] == 750

stats = bp.get_stats()
assert stats["total"] == 750
assert stats["queue"] == 0
assert stats["failed_batches"] == 0
assert stats["total_batches"] == 75
assert stats["tables"]["Users"] == 750

def test_multiple_tables(self):
assert self.client.sql().query('DELETE FROM "Posts" WHERE 1 = 1').is_success()
assert self.client.sql().query('DELETE FROM "Users" WHERE 1 = 1').is_success()

bp = BulkProcessor(
self.client,
@@ -141,10 +176,11 @@ def test_multiple_tables(self):
assert r.is_success()
assert "summaries" in r
assert r["summaries"][0]["proof"] == 33 * 7

stats = bp.get_stats()
assert stats["queue"] == 0
assert stats["failed_batches"] == 0
assert stats["total_batches"] == 14
assert stats["tables"]["Posts"] == 33 * 9
assert stats["tables"]["Users"] == 33 * 7
assert stats["total"] == stats["tables"]["Posts"] + stats["tables"]["Users"]
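
For reference, a minimal usage sketch of the pattern these tests exercise (the database name and record shape are placeholders, and credentials are assumed to be configured in the environment):

```python
from xata.client import XataClient
from xata.helpers import BulkProcessor

client = XataClient(db_name="my_db")  # placeholder; assumes API key set up in the environment
bp = BulkProcessor(client, thread_pool_size=4, batch_size=50)

# queue records; worker threads drain them in batches in the background
bp.put_records("Posts", [{"title": f"post {i}"} for i in range(1000)])

# block until every batch is written; as of this PR this also stops the workers
bp.flush_queue()

stats = bp.get_stats()
assert stats["queue"] == 0
assert stats["total"] == stats["tables"]["Posts"]
```

Note that `flush_queue()` now joins the worker threads, so a processor is finished after one flush; start a new `BulkProcessor` for a new ingestion run.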
6 changes: 3 additions & 3 deletions tests/unit-tests/helpers_bulk_processor_test.py
@@ -35,15 +35,15 @@ def test_bulk_processor_init(self):

with pytest.raises(Exception) as e:
BulkProcessor(client, batch_size=-1)
assert str(e.value) == "batch size can not be less than one, default: 25"
assert str(e.value) == "batch size can not be less than one, default: 50"

with pytest.raises(Exception) as e:
BulkProcessor(client, flush_interval=-1)
assert str(e.value) == "flush interval can not be negative, default: 5.000000"
assert str(e.value) == "flush interval can not be negative, default: 2.000000"

with pytest.raises(Exception) as e:
BulkProcessor(client, processing_timeout=-1)
assert str(e.value) == "processing timeout can not be negative, default: 0.025000"
assert str(e.value) == "processing timeout can not be negative, default: 0.050000"

def test_bulk_processor_stats(self):
client = XataClient(api_key="api_key", workspace_id="ws_id")
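
The updated messages track the new defaults in xata/helpers.py (batch size 50, flush interval 2.0 s, processing timeout 0.05 s; see below). A quick sketch of what these assertions exercise:

```python
from xata.client import XataClient
from xata.helpers import BulkProcessor

client = XataClient(api_key="api_key", workspace_id="ws_id")

try:
    BulkProcessor(client, batch_size=-1)
except Exception as exc:
    print(exc)  # batch size can not be less than one, default: 50
```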
2 changes: 1 addition & 1 deletion xata/client.py
@@ -39,7 +39,7 @@

# TODO this is a manual task, to keep in sync with pyproject.toml
# could/should be automated to keep in sync
__version__ = "1.2.1"
__version__ = "1.2.2"

PERSONAL_API_KEY_LOCATION = "~/.config/xata/key"
DEFAULT_DATA_PLANE_DOMAIN = "xata.sh"
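
The TODO above still stands; one possible way to automate the sync (an illustration, not part of this PR) is to derive the string from the installed package metadata instead of hardcoding it:

```python
from importlib.metadata import PackageNotFoundError, version

try:
    __version__ = version("xata")  # resolves from the installed package metadata
except PackageNotFoundError:
    __version__ = "0.0.0"  # fallback when running from an uninstalled source tree
```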
91 changes: 50 additions & 41 deletions xata/helpers.py
@@ -27,11 +27,11 @@
from .client import XataClient

BP_DEFAULT_THREAD_POOL_SIZE = 4
BP_DEFAULT_BATCH_SIZE = 25
BP_DEFAULT_FLUSH_INTERVAL = 5
BP_DEFAULT_PROCESSING_TIMEOUT = 0.025
BP_DEFAULT_BATCH_SIZE = 50
BP_DEFAULT_FLUSH_INTERVAL = 2
BP_DEFAULT_PROCESSING_TIMEOUT = 0.05
BP_DEFAULT_THROW_EXCEPTION = False
BP_VERSION = "0.3.0"
BP_VERSION = "0.3.1"
TRX_MAX_OPERATIONS = 1000
TRX_VERSION = "0.1.0"
TRX_BACKOFF = 0.1
@@ -85,10 +85,13 @@ def __init__(
self.flush_interval = flush_interval
self.failed_batches_queue = []
self.throw_exception = throw_exception
self.stats = {"total": 0, "queue": 0, "failed_batches": 0, "tables": {}}

self.stats = {"total": 0, "queue": 0, "failed_batches": 0, "total_batches": 0, "tables": {}}
self.stats_lock = Lock()
self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

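# workers poll until flush_queue() flips worker_active to False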
self.thread_workers = []
self.worker_active = True
self.records = self.Records(self.batch_size, self.flush_interval, self.logger)

for i in range(thread_pool_size):
@@ -110,12 +113,16 @@ def process(self, id: int):
self.processing_timeout,
)
)
while True:
while self.worker_active:
sleep_backoff = 5 # slow down if no records exist
time.sleep(self.processing_timeout * sleep_backoff)

# process
batch = self.records.next_batch()
if "table" in batch and len(batch["records"]) > 0:
try:
r = self.client.records().bulk_insert(batch["table"], {"records": batch["records"]})
if r.status_code != 200:
if not r.is_success():
self.logger.error(
"thread #%d: unable to process batch for table '%s', with error: %d - %s"
% (id, batch["table"], r.status_code, r.json())
@@ -137,14 +144,16 @@
"thread #%d: pushed a batch of %d records to table %s"
% (id, len(batch["records"]), batch["table"])
)
# with self.stats_lock:
self.stats["total"] += len(batch["records"])
self.stats["queue"] = self.records.size()
if batch["table"] not in self.stats["tables"]:
self.stats["tables"][batch["table"]] = 0
self.stats["tables"][batch["table"]] += len(batch["records"])
self.stats["total_batches"] += 1
except Exception as exc:
logging.error("thread #%d: %s" % (id, exc))
time.sleep(self.processing_timeout)
sleep_backoff = 1 # keep velocity

def put_record(self, table_name: str, record: dict):
"""
@@ -179,26 +188,38 @@ def get_stats(self):
"""
return self.stats

def get_queue_size(self) -> int:
with self.stats_lock:
return self.stats["queue"]

def flush_queue(self):
"""
Flush all records from the queue.
Flush all records from the queue. Call this as you close the ingestion operation.

https://github.com/xataio/xata-py/issues/184
"""
self.logger.debug("flushing queue with %d records .." % (self.records.size()))

# force flush the records queue and shorten the processing times
self.records.force_queue_flush()
self.processing_timeout = 0.001
wait = 0.005 * len(self.thread_workers)

while self.records.size() > 0:
self.logger.debug("flushing queue with %d records." % self.stats["queue"])
time.sleep(wait)
# back off to wait for all threads to run at least once
if self.records.size() == 0 or self.get_queue_size() == 0:
time.sleep(self.processing_timeout * len(self.thread_workers))

# Last poor man's check that the queue is fully flushed
if self.records.size() > 0 or self.stats["queue"] > 0:
self.logger.debug("one more flush interval necessary with queue at %d records." % self.stats["queue"])
time.sleep(wait)
# ensure the full records queue is flushed first
while self.records.size() != 0:
time.sleep(self.processing_timeout * len(self.thread_workers))
self.logger.debug("flushing %d records to processing queue." % self.records.size())

# let's make really sure the queue is empty
while self.get_queue_size() > 0:
time.sleep(self.processing_timeout * len(self.thread_workers))
self.logger.debug("flushing processor queue with %d records." % self.stats["queue"])

self.worker_active = False
for worker in self.thread_workers:
worker.join()

class Records(object):
"""
@@ -212,7 +233,6 @@ def __init__(self, batch_size: int, flush_interval: int, logger):
"""
self.batch_size = batch_size
self.flush_interval = flush_interval
self.force_flush = False
self.logger = logger

self.store = dict()
@@ -224,10 +244,8 @@ def force_queue_flush(self):
Force next batch to be available
https://github.com/xataio/xata-py/issues/184
"""
with self.lock:
self.force_flush = True
self.flush_interval = 0.001
self.batch_size = 1
# push for immediate flushes
self.flush_interval = 0

def put(self, table_name: str, records: list[dict]):
"""
@@ -250,38 +268,30 @@ def next_batch(self) -> dict:

:returns dict
"""
if self.size() == 0:
return {}
table_name = ""
with self.lock:
names = list(self.store.keys())
if len(names) == 0:
return {}

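# round-robin: advance the pointer to the next table in the store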
self.store_ptr += 1
if len(names) <= self.store_ptr:
self.store_ptr = 0
table_name = names[self.store_ptr]

rs = []
if self.length(table_name) == 0:
return {"table": table_name, "records": rs}

with self.store[table_name]["lock"]:
# flush interval exceeded
time_elapsed = time.time() - self.store[table_name]["flushed"]
flush_needed = time_elapsed > self.flush_interval
if flush_needed and len(self.store[table_name]["records"]) > 0:
self.logger.debug(
"flushing table '%s' with %d records after interval %s > %d"
% (
table_name,
len(self.store[table_name]["records"]),
time_elapsed,
self.flush_interval,
)
)
flush_needed = time_elapsed >= self.flush_interval
# force flush table, batch size reached or timer exceeded
if self.force_flush or len(self.store[table_name]["records"]) >= self.batch_size or flush_needed:
if len(self.store[table_name]["records"]) >= self.batch_size or flush_needed:
self.store[table_name]["flushed"] = time.time()
rs = self.store[table_name]["records"][0 : self.batch_size]
del self.store[table_name]["records"][0 : self.batch_size]
return {"table": table_name, "records": rs}
return {"table": table_name, "records": rs}

def length(self, table_name: str) -> int:
"""
@@ -296,8 +306,7 @@ def size(self) -> int:
"""
Get total size of stored records
"""
with self.lock:
return sum([len(self.store[n]["records"]) for n in self.store.keys()])
return sum([self.length(n) for n in self.store.keys()])


def to_rfc339(dt: datetime, tz=timezone.utc) -> str:
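
Taken together, the changes above replace the `force_flush` flag with a simpler rule: a table's partial batch is flushed once `flush_interval` has elapsed, and `force_queue_flush()` just drops the interval to 0 so every poll drains a batch. A condensed, self-contained model of that logic (illustrative only, not the library code; a single lock stands in for the per-table locks):

```python
import time
from threading import Lock

class MiniRecords:
    def __init__(self, batch_size=50, flush_interval=2.0):
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.store = {}  # table name -> {"records": [...], "flushed": timestamp}
        self.ptr = -1
        self.lock = Lock()

    def put(self, table: str, records: list):
        with self.lock:
            slot = self.store.setdefault(table, {"records": [], "flushed": time.time()})
            slot["records"] += records

    def force_queue_flush(self):
        # push for immediate flushes: every subsequent poll counts as due
        self.flush_interval = 0

    def next_batch(self) -> dict:
        with self.lock:
            names = list(self.store)
            if not names:
                return {}
            self.ptr = (self.ptr + 1) % len(names)  # round-robin across tables
            table = names[self.ptr]
            slot = self.store[table]
            due = (time.time() - slot["flushed"]) >= self.flush_interval
            batch = []
            if len(slot["records"]) >= self.batch_size or due:
                slot["flushed"] = time.time()
                batch = slot["records"][: self.batch_size]
                del slot["records"][: self.batch_size]
            return {"table": table, "records": batch}

q = MiniRecords(batch_size=3)
q.put("Posts", [{"n": i} for i in range(5)])
print(q.next_batch())  # 3 records: batch size reached
print(q.next_batch())  # empty: 2 left, interval not yet elapsed
q.force_queue_flush()
print(q.next_batch())  # remaining 2 records: interval is now 0
```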