Skip to content

Commit

Permalink
fix(bindings/python): __iter__ for Row and Cursor (#576)
Browse files Browse the repository at this point in the history
  • Loading branch information
everpcpc authored Jan 21, 2025
1 parent 53b410e commit 4e59c44
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 24 deletions.
8 changes: 6 additions & 2 deletions bindings/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,12 @@ class BlockingDatabendCursor:
def close(self) -> None: ...
def execute(self, operation: str, params: list[string] | tuple[string] = None) -> None | int: ...
def executemany(self, operation: str, params: list[list[string] | tuple[string]]) -> None | int: ...
def fetchone(self) -> Row: ...
def fetchone(self) -> Row | None: ...
def fetchmany(self, size: int = 1) -> list[Row]: ...
def fetchall(self) -> list[Row]: ...
def next(self) -> Row: ...
def __iter__(self) -> Self: ...
def __next__(self) -> Row: ...
```

### Row
Expand All @@ -197,7 +200,8 @@ class BlockingDatabendCursor:
class Row:
def values(self) -> tuple: ...
def __len__(self) -> int: ...
def __iter__(self) -> list: ...
def __iter__(self) -> Self: ...
def __next__(self) -> Value: ...
def __dict__(self) -> dict: ...
def __getitem__(self, key: int | str) -> any: ...
```
Expand Down
15 changes: 14 additions & 1 deletion bindings/python/src/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::collections::BTreeMap;
use std::path::Path;
use std::sync::Arc;

use pyo3::exceptions::{PyAttributeError, PyException};
use pyo3::exceptions::{PyAttributeError, PyException, PyStopIteration};
use pyo3::prelude::*;
use pyo3::types::{PyList, PyTuple};
use tokio::sync::Mutex;
Expand Down Expand Up @@ -304,6 +304,19 @@ impl BlockingDatabendCursor {
None => Ok(vec![]),
}
}

pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
pub fn __next__(&mut self, py: Python) -> PyResult<Row> {
match self.fetchone(py)? {
Some(row) => Ok(row),
None => Err(PyStopIteration::new_err("Rows exhausted")),
}
}
pub fn next(&mut self, py: Python) -> PyResult<Row> {
self.__next__(py)
}
}

fn format_csv<'p>(parameters: Vec<Bound<'p, PyAny>>) -> PyResult<Vec<u8>> {
Expand Down
41 changes: 25 additions & 16 deletions bindings/python/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,55 +154,64 @@ impl<'py> IntoPyObject<'py> for NumberValue {
}

#[pyclass(module = "databend_driver")]
pub struct Row(databend_driver::Row);
pub struct Row {
inner: databend_driver::Row,
idx: usize,
}

impl Row {
pub fn new(row: databend_driver::Row) -> Self {
Row(row)
Row { inner: row, idx: 0 }
}
}

#[pymethods]
impl Row {
pub fn values<'p>(&'p self, py: Python<'p>) -> PyResult<Bound<'p, PyTuple>> {
let vals = self.0.values().iter().map(|v| Value(v.clone()));
let vals = self.inner.values().iter().map(|v| Value(v.clone()));
let tuple = PyTuple::new(py, vals)?;
Ok(tuple)
}

pub fn __len__(&self) -> usize {
self.0.len()
self.inner.len()
}

pub fn __iter__<'p>(&'p self, py: Python<'p>) -> PyResult<Bound<'p, PyList>> {
let vals = self.0.values().iter().map(|v| Value(v.clone()));
let list = PyList::new(py, vals)?;
Ok(list.into_bound())
pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
pub fn __next__(&mut self) -> PyResult<Value> {
if self.idx >= self.inner.len() {
return Err(PyStopIteration::new_err("Columns exhausted"));
}
let val = self.get_by_index(self.idx)?;
self.idx += 1;
Ok(val)
}

pub fn __dict__<'p>(&'p self, py: Python<'p>) -> PyResult<Bound<'p, PyDict>> {
let dict = PyDict::new(py);
let schema = self.0.schema();
for (field, value) in schema.fields().iter().zip(self.0.values()) {
let schema = self.inner.schema();
for (field, value) in schema.fields().iter().zip(self.inner.values()) {
dict.set_item(&field.name, Value(value.clone()))?;
}
Ok(dict.into_bound())
}

fn get_by_index(&self, idx: usize) -> PyResult<Value> {
Ok(Value(self.0.values()[idx].clone()))
Ok(Value(self.inner.values()[idx].clone()))
}

fn get_by_field(&self, field: &str) -> PyResult<Value> {
let schema = self.0.schema();
let schema = self.inner.schema();
let idx = schema
.fields()
.iter()
.position(|f| f.name == field)
.ok_or_else(|| {
PyException::new_err(format!("field '{}' not found in schema", field))
})?;
Ok(Value(self.0.values()[idx].clone()))
Ok(Value(self.inner.values()[idx].clone()))
}

pub fn __getitem__<'p>(&'p self, key: Bound<'p, PyAny>) -> PyResult<Value> {
Expand Down Expand Up @@ -244,9 +253,9 @@ impl RowIterator {
match streamer.lock().await.next().await {
Some(val) => match val {
Err(e) => Err(PyException::new_err(format!("{}", e))),
Ok(ret) => Ok(Row(ret)),
Ok(ret) => Ok(Row::new(ret)),
},
None => Err(PyStopIteration::new_err("The iterator is exhausted")),
None => Err(PyStopIteration::new_err("Rows exhausted")),
}
})
}
Expand All @@ -260,7 +269,7 @@ impl RowIterator {
match streamer.lock().await.next().await {
Some(val) => match val {
Err(e) => Err(PyException::new_err(format!("{}", e))),
Ok(ret) => Ok(Row(ret)),
Ok(ret) => Ok(Row::new(ret)),
},
None => Err(PyStopAsyncIteration::new_err("The iterator is exhausted")),
}
Expand Down
33 changes: 28 additions & 5 deletions bindings/python/tests/cursor/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,14 @@ def _(context):
def _(context, input, output):
context.cursor.execute(f"SELECT '{input}'")
row = context.cursor.fetchone()

# getitem
assert output == row[0], f"output: {output}"

# iter
val = next(row)
assert val == output, f"val: {val}"


@then("Select types should be expected native types")
async def _(context):
Expand Down Expand Up @@ -127,16 +133,33 @@ def _(context):
(-3, 3, 3.0, '3', '2', '2016-04-04', '2016-04-04 11:30:00')
"""
)
context.cursor.execute("SELECT * FROM test")
rows = context.cursor.fetchall()
ret = []
for row in rows:
ret.append(row.values())
expected = [
(-1, 1, 1.0, "1", "1", date(2011, 3, 6), datetime(2011, 3, 6, 6, 20)),
(-2, 2, 2.0, "2", "2", date(2012, 5, 31), datetime(2012, 5, 31, 11, 20)),
(-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)),
]

# fetchall
context.cursor.execute("SELECT * FROM test")
rows = context.cursor.fetchall()
ret = []
for row in rows:
ret.append(row.values())
assert ret == expected, f"ret: {ret}"

# fetchmany
context.cursor.execute("SELECT * FROM test")
rows = context.cursor.fetchmany(3)
ret = []
for row in rows:
ret.append(row.values())
assert ret == expected, f"ret: {ret}"

# iter
context.cursor.execute("SELECT * FROM test")
ret = []
for row in context.cursor:
ret.append(row.values())
assert ret == expected, f"ret: {ret}"


Expand Down

0 comments on commit 4e59c44

Please sign in to comment.