Skip to content

Commit

Permalink
Vector: Fix type checking and compatibility with SQLAlchemy 1.x
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Dec 22, 2023
1 parent 51e5874 commit 1070b37
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/sqlalchemy_cratedb/type/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
- The type implementation might want to be accompanied by corresponding support
for the `KNN_MATCH` function, similar to what the dialect already offers for
fulltext search through its `Match` predicate.
- After dropping support for SQLAlchemy 1.3, use
`class FloatVector(sa.TypeDecorator[t.Sequence[float]]):`
## Origin
This module is based on the corresponding pgvector implementation
Expand All @@ -44,7 +46,7 @@
__all__ = ["FloatVector"]


def from_db(value: t.Iterable) -> t.Optional[npt.ArrayLike]:
def from_db(value: t.Iterable) -> t.Optional["npt.ArrayLike"]:
import numpy as np

# from `pgvector.utils`
Expand Down Expand Up @@ -77,8 +79,7 @@ def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]:
return value


class FloatVector(sa.TypeDecorator[t.Sequence[float]]):

class FloatVector(sa.TypeDecorator):
"""
An improved implementation of the `FloatVector` data type for CrateDB,
compared to the previous implementation on behalf of the LangChain adapter.
Expand Down Expand Up @@ -146,14 +147,14 @@ def __init__(self, dimensions: int = None):
def as_generic(self):
return sa.ARRAY

def bind_processor(self, dialect: sa.Dialect) -> t.Callable:
def bind_processor(self, dialect: sa.engine.Dialect) -> t.Callable:
def process(value: t.Iterable) -> t.Optional[t.List]:
return to_db(value, self.dimensions)

return process

def result_processor(self, dialect: sa.Dialect, coltype: t.Any) -> t.Callable:
def process(value: t.Any) -> t.Optional[npt.ArrayLike]:
def result_processor(self, dialect: sa.engine.Dialect, coltype: t.Any) -> t.Callable:
def process(value: t.Any) -> t.Optional["npt.ArrayLike"]:
return from_db(value)

return process

0 comments on commit 1070b37

Please sign in to comment.