diff --git a/fakesnow/arrow.py b/fakesnow/arrow.py index 24d0cd5..dcc9e3e 100644 --- a/fakesnow/arrow.py +++ b/fakesnow/arrow.py @@ -1,37 +1,35 @@ -from typing import Any - import pyarrow as pa +from fakesnow.types import ColumnInfo + + +def with_sf_metadata(schema: pa.Schema, rowtype: list[ColumnInfo]) -> pa.Schema: + # expected by the snowflake connector + # uses rowtype to populate metadata, rather than the arrow schema type, for consistency with + # rowtype returned in the response + + assert len(schema) == len(rowtype), f"schema and rowtype must be same length but {len(schema)=} {len(rowtype)=}" -def with_sf_metadata(schema: pa.Schema) -> pa.Schema: # see https://github.com/snowflakedb/snowflake-connector-python/blob/e9393a6/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.cpp#L32 # and https://github.com/snowflakedb/snowflake-connector-python/blob/e9393a6/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.cpp#L10 - fms = [] - for i, t in enumerate(schema.types): - f = schema.field(i) - - # TODO: precision, scale, charLength etc. 
for all types - if t == pa.bool_(): - fm = f.with_metadata({"logicalType": "BOOLEAN"}) - elif t == pa.int64(): - # scale and precision required, see here - # https://github.com/snowflakedb/snowflake-connector-python/blob/416ff57/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.cpp#L147 - fm = f.with_metadata({"logicalType": "FIXED", "precision": "38", "scale": "0"}) - elif t == pa.float64(): - fm = f.with_metadata({"logicalType": "REAL"}) - elif isinstance(t, pa.Decimal128Type): - fm = f.with_metadata({"logicalType": "FIXED", "precision": str(t.precision), "scale": str(t.scale)}) - elif t == pa.string(): - # TODO: set charLength to size of column - fm = f.with_metadata({"logicalType": "TEXT", "charLength": "16777216"}) - else: - raise NotImplementedError(f"Unsupported Arrow type: {t}") - fms.append(fm) + fms = [ + schema.field(i).with_metadata( + { + "logicalType": c["type"].upper(), + # required for FIXED type see + # https://github.com/snowflakedb/snowflake-connector-python/blob/416ff57/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.cpp#L147 + "precision": str(c["precision"] or 38), + "scale": str(c["scale"] or 0), + "charLength": str(c["length"] or 0), + } + ) + for i, c in enumerate(rowtype) + ] return pa.schema(fms) -def to_ipc(table: pa.Table) -> pa.Buffer: +def to_ipc(table: pa.Table, rowtype: list[ColumnInfo]) -> pa.Buffer: batches = table.to_batches() if len(batches) != 1: raise NotImplementedError(f"{len(batches)} batches") @@ -39,29 +37,7 @@ def to_ipc(table: pa.Table) -> pa.Buffer: sink = pa.BufferOutputStream() - with pa.ipc.new_stream(sink, with_sf_metadata(table.schema)) as writer: + with pa.ipc.new_stream(sink, with_sf_metadata(table.schema, rowtype)) as writer: writer.write_batch(batch) return sink.getvalue() - - -# TODO: should this be derived before with_schema? 
-def to_rowtype(schema: pa.Schema) -> list[dict[str, Any]]: - return [ - { - "name": f.name, - # TODO - # "database": "", - # "schema": "", - # "table": "", - "nullable": f.nullable, - "type": f.metadata.get(b"logicalType").decode("utf-8").lower(), # type: ignore - # TODO - # "byteLength": 20, - "length": int(f.metadata.get(b"charLength")) if f.metadata.get(b"charLength") else None, # type: ignore - "scale": int(f.metadata.get(b"scale")) if f.metadata.get(b"scale") else None, # type: ignore - "precision": int(f.metadata.get(b"precision")) if f.metadata.get(b"precision") else None, # type: ignore - "collation": None, - } - for f in schema - ] diff --git a/fakesnow/server.py b/fakesnow/server.py index 4f386a2..2391f62 100644 --- a/fakesnow/server.py +++ b/fakesnow/server.py @@ -48,14 +48,14 @@ async def query_request(request: Request) -> JSONResponse: # only a single sql statement is sent at a time by the python snowflake connector cur = await run_in_threadpool(conn.cursor().execute, sql_text) + rowtype = describe_as_rowtype(cur._describe_last_sql()) # noqa: SLF001 + if cur._arrow_table: # noqa: SLF001 - batch_bytes = to_ipc(cur._arrow_table) # noqa: SLF001 + batch_bytes = to_ipc(cur._arrow_table, rowtype) # noqa: SLF001 rowset_b64 = b64encode(batch_bytes).decode("utf-8") else: rowset_b64 = "" - rowtype = describe_as_rowtype(cur._describe_last_sql()) # noqa: SLF001 - return JSONResponse( { "data": { diff --git a/tests/test_arrow.py b/tests/test_arrow.py index 69d6e40..78619f7 100644 --- a/tests/test_arrow.py +++ b/tests/test_arrow.py @@ -4,15 +4,30 @@ import pyarrow as pa from fakesnow.arrow import to_ipc, with_sf_metadata +from fakesnow.types import ColumnInfo, describe_as_rowtype + + +def rowtype(types: list[str]) -> list[ColumnInfo]: + return describe_as_rowtype([("test", typ, None, None, None, None) for typ in types]) def test_with_sf_metadata() -> None: # see https://arrow.apache.org/docs/python/api/datatypes.html - def f(t: pa.DataType) -> dict: - return 
with_sf_metadata(pa.schema([pa.field(str(t), t)])).field(0).metadata + def f(t: pa.DataType, rowtype: list[ColumnInfo]) -> dict: + return with_sf_metadata(pa.schema([pa.field(str(t), t)]), rowtype).field(0).metadata - assert f(pa.string()) == {b"logicalType": b"TEXT", b"charLength": b"16777216"} - assert f(pa.decimal128(10, 2)) == {b"logicalType": b"FIXED", b"precision": b"10", b"scale": b"2"} + assert f(pa.string(), rowtype(["VARCHAR"])) == { + b"logicalType": b"TEXT", + b"precision": b"38", + b"scale": b"0", + b"charLength": b"16777216", + } + assert f(pa.decimal128(10, 2), rowtype(["DECIMAL(10,2)"])) == { + b"logicalType": b"FIXED", + b"precision": b"10", + b"scale": b"2", + b"charLength": b"0", + } def test_ipc_writes_sf_metadata() -> None: @@ -23,7 +38,7 @@ def test_ipc_writes_sf_metadata() -> None: ) table = pa.Table.from_pandas(df) - table_bytes = to_ipc(table) + table_bytes = to_ipc(table, rowtype(["VARCHAR"])) batch = next(iter(pa.ipc.open_stream(table_bytes))) @@ -31,6 +46,8 @@ def test_ipc_writes_sf_metadata() -> None: assert pa.table(batch) == table assert batch.schema.field(0).metadata == { b"logicalType": b"TEXT", + b"precision": b"38", + b"scale": b"0", b"charLength": b"16777216", }