From 4e59c44e6f46427aedd75a7918c11161c0c1768f Mon Sep 17 00:00:00 2001 From: everpcpc Date: Tue, 21 Jan 2025 17:12:21 +0800 Subject: [PATCH] fix(bindings/python): __iter__ for Row and Cursor (#576) --- bindings/python/README.md | 8 +++- bindings/python/src/blocking.rs | 15 ++++++- bindings/python/src/types.rs | 41 +++++++++++-------- bindings/python/tests/cursor/steps/binding.py | 33 ++++++++++++--- 4 files changed, 73 insertions(+), 24 deletions(-) diff --git a/bindings/python/README.md b/bindings/python/README.md index 11695034..89f9eb13 100644 --- a/bindings/python/README.md +++ b/bindings/python/README.md @@ -186,9 +186,12 @@ class BlockingDatabendCursor: def close(self) -> None: ... def execute(self, operation: str, params: list[string] | tuple[string] = None) -> None | int: ... def executemany(self, operation: str, params: list[list[string] | tuple[string]]) -> None | int: ... - def fetchone(self) -> Row: ... + def fetchone(self) -> Row | None: ... def fetchmany(self, size: int = 1) -> list[Row]: ... def fetchall(self) -> list[Row]: ... + def next(self) -> Row: ... + def __iter__(self) -> Self: ... + def __next__(self) -> Row: ... ``` ### Row @@ -197,7 +200,8 @@ class BlockingDatabendCursor: class Row: def values(self) -> tuple: ... def __len__(self) -> int: ... - def __iter__(self) -> list: ... + def __iter__(self) -> Self: ... + def __next__(self) -> Value: ... def __dict__(self) -> dict: ... def __getitem__(self, key: int | str) -> any: ... ``` diff --git a/bindings/python/src/blocking.rs b/bindings/python/src/blocking.rs index 990d5cff..59d4acce 100644 --- a/bindings/python/src/blocking.rs +++ b/bindings/python/src/blocking.rs @@ -16,7 +16,7 @@ use std::collections::BTreeMap; use std::path::Path; use std::sync::Arc; -use pyo3::exceptions::{PyAttributeError, PyException}; +use pyo3::exceptions::{PyAttributeError, PyException, PyStopIteration}; use pyo3::prelude::*; use pyo3::types::{PyList, PyTuple}; use tokio::sync::Mutex; @@ -304,6 +304,19 @@ impl BlockingDatabendCursor { None => Ok(vec![]), } } + + pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + pub fn __next__(&mut self, py: Python) -> PyResult { + match self.fetchone(py)? { + Some(row) => Ok(row), + None => Err(PyStopIteration::new_err("Rows exhausted")), + } + } + pub fn next(&mut self, py: Python) -> PyResult { + self.__next__(py) + } } fn format_csv<'p>(parameters: Vec>) -> PyResult> { diff --git a/bindings/python/src/types.rs b/bindings/python/src/types.rs index 097b3ad1..35f7aef3 100644 --- a/bindings/python/src/types.rs +++ b/bindings/python/src/types.rs @@ -154,47 +154,56 @@ impl<'py> IntoPyObject<'py> for NumberValue { } #[pyclass(module = "databend_driver")] -pub struct Row(databend_driver::Row); +pub struct Row { + inner: databend_driver::Row, + idx: usize, +} impl Row { pub fn new(row: databend_driver::Row) -> Self { - Row(row) + Row { inner: row, idx: 0 } } } #[pymethods] impl Row { pub fn values<'p>(&'p self, py: Python<'p>) -> PyResult> { - let vals = self.0.values().iter().map(|v| Value(v.clone())); + let vals = self.inner.values().iter().map(|v| Value(v.clone())); let tuple = PyTuple::new(py, vals)?; Ok(tuple) } pub fn __len__(&self) -> usize { - self.0.len() + self.inner.len() } - pub fn __iter__<'p>(&'p self, py: Python<'p>) -> PyResult> { - let vals = self.0.values().iter().map(|v| Value(v.clone())); - let list = PyList::new(py, vals)?; - Ok(list.into_bound()) + pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + pub fn __next__(&mut self) -> PyResult { + if self.idx >= self.inner.len() { + return Err(PyStopIteration::new_err("Columns exhausted")); + } + let val = self.get_by_index(self.idx)?; + self.idx += 1; + Ok(val) } pub fn __dict__<'p>(&'p self, py: Python<'p>) -> PyResult> { let dict = PyDict::new(py); - let schema = self.0.schema(); - for (field, value) in schema.fields().iter().zip(self.0.values()) { + let schema = self.inner.schema(); + for (field, value) in schema.fields().iter().zip(self.inner.values()) { dict.set_item(&field.name, Value(value.clone()))?; } Ok(dict.into_bound()) } fn get_by_index(&self, idx: usize) -> PyResult { - Ok(Value(self.0.values()[idx].clone())) + Ok(Value(self.inner.values()[idx].clone())) } fn get_by_field(&self, field: &str) -> PyResult { - let schema = self.0.schema(); + let schema = self.inner.schema(); let idx = schema .fields() .iter() @@ -202,7 +211,7 @@ impl Row { .ok_or_else(|| { PyException::new_err(format!("field '{}' not found in schema", field)) })?; - Ok(Value(self.0.values()[idx].clone())) + Ok(Value(self.inner.values()[idx].clone())) } pub fn __getitem__<'p>(&'p self, key: Bound<'p, PyAny>) -> PyResult { @@ -244,9 +253,9 @@ impl RowIterator { match streamer.lock().await.next().await { Some(val) => match val { Err(e) => Err(PyException::new_err(format!("{}", e))), - Ok(ret) => Ok(Row(ret)), + Ok(ret) => Ok(Row::new(ret)), }, - None => Err(PyStopIteration::new_err("The iterator is exhausted")), + None => Err(PyStopIteration::new_err("Rows exhausted")), } }) } @@ -260,7 +269,7 @@ impl RowIterator { match streamer.lock().await.next().await { Some(val) => match val { Err(e) => Err(PyException::new_err(format!("{}", e))), - Ok(ret) => Ok(Row(ret)), + Ok(ret) => Ok(Row::new(ret)), }, None => Err(PyStopAsyncIteration::new_err("The iterator is exhausted")), } diff --git a/bindings/python/tests/cursor/steps/binding.py b/bindings/python/tests/cursor/steps/binding.py index 41225219..d4380ce6 100644 --- a/bindings/python/tests/cursor/steps/binding.py +++ b/bindings/python/tests/cursor/steps/binding.py @@ -52,8 +52,14 @@ def _(context): def _(context, input, output): context.cursor.execute(f"SELECT '{input}'") row = context.cursor.fetchone() + + # getitem assert output == row[0], f"output: {output}" + # iter + val = next(row) + assert val == output, f"val: {val}" + @then("Select types should be expected native types") async def _(context): @@ -127,16 +133,33 @@ def _(context): (-3, 3, 3.0, '3', '2', '2016-04-04', '2016-04-04 11:30:00') """ ) - context.cursor.execute("SELECT * FROM test") - rows = context.cursor.fetchall() - ret = [] - for row in rows: - ret.append(row.values()) expected = [ (-1, 1, 1.0, "1", "1", date(2011, 3, 6), datetime(2011, 3, 6, 6, 20)), (-2, 2, 2.0, "2", "2", date(2012, 5, 31), datetime(2012, 5, 31, 11, 20)), (-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)), ] + + # fetchall + context.cursor.execute("SELECT * FROM test") + rows = context.cursor.fetchall() + ret = [] + for row in rows: + ret.append(row.values()) + assert ret == expected, f"ret: {ret}" + + # fetchmany + context.cursor.execute("SELECT * FROM test") + rows = context.cursor.fetchmany(3) + ret = [] + for row in rows: + ret.append(row.values()) + assert ret == expected, f"ret: {ret}" + + # iter + context.cursor.execute("SELECT * FROM test") + ret = [] + for row in context.cursor: + ret.append(row.values()) assert ret == expected, f"ret: {ret}"