From 7f813fb35671dfb25a1c6d91a4649e0550bbd513 Mon Sep 17 00:00:00 2001 From: everpcpc Date: Mon, 20 Jan 2025 16:21:19 +0800 Subject: [PATCH] feat(bindings/python): add cursor method fetchmany (#574) --- bindings/python/README.md | 1 + bindings/python/src/blocking.rs | 20 ++++++++++ bindings/python/tests/cursor/steps/binding.py | 39 ++++++++++++------- 3 files changed, 46 insertions(+), 14 deletions(-) diff --git a/bindings/python/README.md b/bindings/python/README.md index 24d78637..11695034 100644 --- a/bindings/python/README.md +++ b/bindings/python/README.md @@ -187,6 +187,7 @@ class BlockingDatabendCursor: def execute(self, operation: str, params: list[string] | tuple[string] = None) -> None | int: ... def executemany(self, operation: str, params: list[list[string] | tuple[string]]) -> None | int: ... def fetchone(self) -> Row: ... + def fetchmany(self, size: int = 1) -> list[Row]: ... def fetchall(self) -> list[Row]: ... ``` diff --git a/bindings/python/src/blocking.rs b/bindings/python/src/blocking.rs index 1b9be255..990d5cff 100644 --- a/bindings/python/src/blocking.rs +++ b/bindings/python/src/blocking.rs @@ -264,6 +264,26 @@ impl BlockingDatabendCursor { } } + #[pyo3(signature = (size=1))] + pub fn fetchmany(&mut self, py: Python, size: Option) -> PyResult> { + let mut result = self.buffer.drain(..).collect::>(); + if let Some(ref rows) = self.rows { + let size = size.unwrap_or(1); + while result.len() < size { + let row = wait_for_future(py, async move { + let mut rows = rows.lock().await; + rows.next().await.transpose().map_err(DriverError::new) + })?; + if let Some(row) = row { + result.push(Row::new(row)); + } else { + break; + } + } + } + Ok(result) + } + pub fn fetchall(&mut self, py: Python) -> PyResult> { let mut result = self.buffer.drain(..).collect::>(); match self.rows.take() { diff --git a/bindings/python/tests/cursor/steps/binding.py b/bindings/python/tests/cursor/steps/binding.py index 2403a576..41225219 100644 --- a/bindings/python/tests/cursor/steps/binding.py +++ b/bindings/python/tests/cursor/steps/binding.py @@ -60,49 +60,60 @@ async def _(context): # Binary context.cursor.execute("select to_binary('xyz')") row = context.cursor.fetchone() - assert row[0] == b"xyz", f"Binary: {row.values()}" + expected = (b"xyz",) + assert row.values() == expected, f"Binary: {row.values()}" # Interval context.cursor.execute("select to_interval('1 days')") row = context.cursor.fetchone() - assert row.values() == (timedelta(1),), f"Interval: {row.values()}" + expected = (timedelta(1),) + assert row.values() == expected, f"Interval: {row.values()}" # Decimal context.cursor.execute("SELECT 15.7563::Decimal(8,4), 2.0+3.0") row = context.cursor.fetchone() - assert row.values() == ( - Decimal("15.7563"), - Decimal("5.0"), - ), f"Decimal: {row.values()}" + expected = (Decimal("15.7563"), Decimal("5.0")) + assert row.values() == expected, f"Decimal: {row.values()}" # Array context.cursor.execute("select [10::Decimal(15,2), 1.1+2.3]") row = context.cursor.fetchone() - assert row.values() == ([Decimal("10.00"), Decimal("3.40")],), ( - f"Array: {row.values()}" - ) + expected = [Decimal("10.00"), Decimal("3.40")] + assert row.values() == expected, f"Array: {row.values()}" # Map context.cursor.execute("select {'xx':to_date('2020-01-01')}") row = context.cursor.fetchone() - assert row.values() == ({"xx": date(2020, 1, 1)},), f"Map: {row.values()}" + expected = {"xx": date(2020, 1, 1)} + assert row.values() == expected, f"Map: {row.values()}" # Tuple context.cursor.execute("select (10, '20', to_datetime('2024-04-16 12:34:56.789'))") row = context.cursor.fetchone() - assert row.values() == ((10, "20", datetime(2024, 4, 16, 12, 34, 56, 789000)),), ( - f"Tuple: {row.values()}" + expected = ( + 10, + "20", + datetime(2024, 4, 16, 12, 34, 56, 789000), ) + assert row.values() == expected, f"Tuple: {row.values()}" @then("Select numbers should iterate all rows") def _(context): context.cursor.execute("SELECT number FROM numbers(5)") - rows = context.cursor.fetchall() + + rows = context.cursor.fetchmany(3) + ret = [] + for row in rows: + ret.append(row[0]) + expected = [0, 1, 2] + assert ret == expected, f"ret: {ret}" + + rows = context.cursor.fetchmany(3) ret = [] for row in rows: ret.append(row[0]) - expected = [0, 1, 2, 3, 4] + expected = [3, 4] assert ret == expected, f"ret: {ret}"