Skip to content

Commit

Permalink
refactor: use rowtype for sf metadata in arrow schema
Browse files Browse the repository at this point in the history
  • Loading branch information
tekumara committed Aug 25, 2024
1 parent b967b69 commit e98d227
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 56 deletions.
72 changes: 24 additions & 48 deletions fakesnow/arrow.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,43 @@
from typing import Any

import pyarrow as pa

from fakesnow.types import ColumnInfo


def with_sf_metadata(schema: pa.Schema, rowtype: list[ColumnInfo]) -> pa.Schema:
    """Return a copy of *schema* with the field metadata expected by the snowflake connector.

    Metadata is populated from *rowtype*, rather than from the arrow schema types, for
    consistency with the rowtype returned in the response.

    See https://github.com/snowflakedb/snowflake-connector-python/blob/e9393a6/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.cpp#L32
    and https://github.com/snowflakedb/snowflake-connector-python/blob/e9393a6/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.cpp#L10

    Raises:
        AssertionError: if schema and rowtype differ in length.
    """
    # NB: f-string previously contained stray literal "f" chars before each {…=} placeholder
    assert len(schema) == len(rowtype), f"schema and rowtype must be same length but {len(schema)=} {len(rowtype)=}"

    fms = [
        schema.field(i).with_metadata(
            {
                "logicalType": c["type"].upper(),
                # precision and scale are required for the FIXED type, see
                # https://github.com/snowflakedb/snowflake-connector-python/blob/416ff57/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.cpp#L147
                "precision": str(c["precision"] or 38),
                "scale": str(c["scale"] or 0),
                "charLength": str(c["length"] or 0),
            }
        )
        for i, c in enumerate(rowtype)
    ]
    return pa.schema(fms)


def to_ipc(table: pa.Table, rowtype: list[ColumnInfo]) -> pa.Buffer:
    """Serialize *table* to an Arrow IPC stream, with snowflake field metadata from *rowtype*.

    Raises:
        NotImplementedError: if the table contains more than one record batch.
    """
    batches = table.to_batches()
    # only single-batch tables are supported for now
    if len(batches) != 1:
        raise NotImplementedError(f"{len(batches)} batches")
    batch = batches[0]

    sink = pa.BufferOutputStream()

    with pa.ipc.new_stream(sink, with_sf_metadata(table.schema, rowtype)) as writer:
        writer.write_batch(batch)

    return sink.getvalue()


# TODO: should this be derived before with_schema?
def to_rowtype(schema: pa.Schema) -> list[dict[str, Any]]:
    """Derive snowflake rowtype entries from the sf metadata on each field of *schema*."""

    def _int_or_none(metadata, key):  # type: ignore
        # metadata values are bytes; missing/empty values become None
        raw = metadata.get(key)
        return int(raw) if raw else None

    rows: list[dict[str, Any]] = []
    for field in schema:
        rows.append(
            {
                "name": field.name,
                # TODO
                # "database": "",
                # "schema": "",
                # "table": "",
                "nullable": field.nullable,
                "type": field.metadata.get(b"logicalType").decode("utf-8").lower(),  # type: ignore
                # TODO
                # "byteLength": 20,
                "length": _int_or_none(field.metadata, b"charLength"),
                "scale": _int_or_none(field.metadata, b"scale"),
                "precision": _int_or_none(field.metadata, b"precision"),
                "collation": None,
            }
        )
    return rows
6 changes: 3 additions & 3 deletions fakesnow/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ async def query_request(request: Request) -> JSONResponse:
# only a single sql statement is sent at a time by the python snowflake connector
cur = await run_in_threadpool(conn.cursor().execute, sql_text)

rowtype = describe_as_rowtype(cur._describe_last_sql()) # noqa: SLF001

if cur._arrow_table: # noqa: SLF001
batch_bytes = to_ipc(cur._arrow_table) # noqa: SLF001
batch_bytes = to_ipc(cur._arrow_table, rowtype) # noqa: SLF001
rowset_b64 = b64encode(batch_bytes).decode("utf-8")
else:
rowset_b64 = ""

rowtype = describe_as_rowtype(cur._describe_last_sql()) # noqa: SLF001

return JSONResponse(
{
"data": {
Expand Down
27 changes: 22 additions & 5 deletions tests/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,30 @@
import pyarrow as pa

from fakesnow.arrow import to_ipc, with_sf_metadata
from fakesnow.types import ColumnInfo, describe_as_rowtype


def rowtype(types: list[str]) -> list[ColumnInfo]:
    """Build ColumnInfo entries for the given snowflake column type strings."""
    descriptions = [("test", column_type, None, None, None, None) for column_type in types]
    return describe_as_rowtype(descriptions)


def test_with_sf_metadata() -> None:
    """Field metadata is derived from rowtype, one field per column."""

    # see https://arrow.apache.org/docs/python/api/datatypes.html
    def f(t: pa.DataType, rowtype: list[ColumnInfo]) -> dict:
        return with_sf_metadata(pa.schema([pa.field(str(t), t)]), rowtype).field(0).metadata

    assert f(pa.string(), rowtype(["VARCHAR"])) == {
        b"logicalType": b"TEXT",
        b"precision": b"38",
        b"scale": b"0",
        b"charLength": b"16777216",
    }
    assert f(pa.decimal128(10, 2), rowtype(["DECIMAL(10,2)"])) == {
        b"logicalType": b"FIXED",
        b"precision": b"10",
        b"scale": b"2",
        b"charLength": b"0",
    }


def test_ipc_writes_sf_metadata() -> None:
Expand All @@ -23,14 +38,16 @@ def test_ipc_writes_sf_metadata() -> None:
)

table = pa.Table.from_pandas(df)
table_bytes = to_ipc(table)
table_bytes = to_ipc(table, rowtype(["VARCHAR"]))

batch = next(iter(pa.ipc.open_stream(table_bytes)))

# field and schema metadata is ignored
assert pa.table(batch) == table
assert batch.schema.field(0).metadata == {
b"logicalType": b"TEXT",
b"precision": b"38",
b"scale": b"0",
b"charLength": b"16777216",
}

Expand Down

0 comments on commit e98d227

Please sign in to comment.