Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CrateDB vector: Improve SQLAlchemy model factory #13

Merged
merged 1 commit into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions libs/langchain/langchain/vectorstores/cratedb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .base import BaseModel, CrateDBVectorSearch
from .base import CrateDBVectorSearch

__all__ = [
"BaseModel",
"CrateDBVectorSearch",
]
68 changes: 37 additions & 31 deletions libs/langchain/langchain/vectorstores/cratedb/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import enum
import math
import uuid
from typing import (
Any,
Callable,
Expand All @@ -20,11 +19,12 @@
polyfill_refresh_after_dml,
refresh_table,
)
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.orm import sessionmaker

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.schema.embeddings import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.cratedb.model import ModelFactory
from langchain.vectorstores.pgvector import PGVector


Expand All @@ -38,23 +38,10 @@ class DistanceStrategy(str, enum.Enum):

DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN

Base = declarative_base() # type: Any
# Base = declarative_base(metadata=MetaData(schema="langchain")) # type: Any

_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"


def generate_uuid() -> str:
return str(uuid.uuid4())


class BaseModel(Base):
"""Base model for the SQL stores."""

__abstract__ = True
uuid = sqlalchemy.Column(sqlalchemy.String, primary_key=True, default=generate_uuid)


def _results_to_docs(docs_and_scores: Any) -> List[Document]:
"""Return docs from docs and scores."""
return [doc for doc, _ in docs_and_scores]
Expand Down Expand Up @@ -114,30 +101,47 @@ def __post_init__(

# Need to defer initialization, because dimension size
# can only be figured out at runtime.
self.CollectionStore = None
self.EmbeddingStore = None
self.BaseModel = None
self.CollectionStore = None # type: ignore[assignment]
self.EmbeddingStore = None # type: ignore[assignment]

def __del__(self) -> None:
"""
Work around premature session close.

sqlalchemy.orm.exc.DetachedInstanceError: Parent instance <CollectionStore at 0x1212ca3d0> is not bound
to a Session; lazy load operation of attribute 'embeddings' cannot proceed.
-- https://docs.sqlalchemy.org/en/20/errors.html#error-bhk3

TODO: Review!
""" # noqa: E501
pass

def _init_models(self, embedding: List[float]):
def _init_models(self, embedding: List[float]) -> None:
"""
Create SQLAlchemy models at runtime, when not established yet.
"""

# TODO: Use a better way to run this only once.
if self.CollectionStore is not None and self.EmbeddingStore is not None:
return

size = len(embedding)
self._init_models_with_dimensionality(size=size)

def _init_models_with_dimensionality(self, size: int):
from langchain.vectorstores.cratedb.model import model_factory

self.CollectionStore, self.EmbeddingStore = model_factory(dimensions=size)
def _init_models_with_dimensionality(self, size: int) -> None:
mf = ModelFactory(dimensions=size)
self.BaseModel, self.CollectionStore, self.EmbeddingStore = (
mf.BaseModel, # type: ignore[assignment]
mf.CollectionStore,
mf.EmbeddingStore,
)

def get_collection(
self, session: sqlalchemy.orm.Session
) -> Optional["CollectionStore"]:
def get_collection(self, session: sqlalchemy.orm.Session) -> Any:
if self.CollectionStore is None:
raise RuntimeError(
"Collection can't be accessed without specifying dimension size of embedding vectors"
"Collection can't be accessed without specifying "
"dimension size of embedding vectors"
)
return self.CollectionStore.get_by_name(session, self.collection_name)

Expand Down Expand Up @@ -170,15 +174,17 @@ def add_embeddings(

def create_tables_if_not_exists(self) -> None:
"""
Need to overwrite because `Base` is different from upstream.
Need to overwrite because this `Base` is different from parent's `Base`.
"""
Base.metadata.create_all(self._engine)
mf = ModelFactory()
mf.Base.metadata.create_all(self._engine)

def drop_tables(self) -> None:
"""
Need to overwrite because `Base` is different from upstream.
Need to overwrite because this `Base` is different from parent's `Base`.
"""
Base.metadata.drop_all(self._engine)
mf = ModelFactory()
mf.Base.metadata.drop_all(self._engine)

def delete(
self,
Expand Down
168 changes: 98 additions & 70 deletions libs/langchain/langchain/vectorstores/cratedb/model.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,112 @@
from functools import lru_cache
from typing import Optional, Tuple
import uuid
from typing import Any, Optional, Tuple

import sqlalchemy
from crate.client.sqlalchemy.types import ObjectType
from sqlalchemy.orm import Session, relationship
from sqlalchemy.orm import Session, declarative_base, relationship

from langchain.vectorstores.cratedb.base import BaseModel
from langchain.vectorstores.cratedb.sqlalchemy_type import FloatVector


@lru_cache
def model_factory(dimensions: int):
class CollectionStore(BaseModel):
"""Collection store."""

__tablename__ = "collection"

name = sqlalchemy.Column(sqlalchemy.String)
cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType)

embeddings = relationship(
"EmbeddingStore",
back_populates="collection",
passive_deletes=True,
)

@classmethod
def get_by_name(
cls, session: Session, name: str
) -> Optional["CollectionStore"]:
try:
return (
session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined] # noqa: E501
)
except sqlalchemy.exc.ProgrammingError as ex:
if "RelationUnknown" not in str(ex):
raise
return None

@classmethod
def get_or_create(
cls,
session: Session,
name: str,
cmetadata: Optional[dict] = None,
) -> Tuple["CollectionStore", bool]:
"""
Get or create a collection.
Returns [Collection, bool] where the bool is True if the collection was created.
"""
created = False
collection = cls.get_by_name(session, name)
if collection:
def generate_uuid() -> str:
return str(uuid.uuid4())


class ModelFactory:
"""Provide SQLAlchemy model objects at runtime."""

def __init__(self, dimensions: Optional[int] = None):
# While it does not have any function here, you will still need to supply a
# dummy dimension size value for operations like deleting records.
self.dimensions = dimensions or 1024

Base: Any = declarative_base()

# Optional: Use a custom schema for the langchain tables.
# Base = declarative_base(metadata=MetaData(schema="langchain")) # type: Any

class BaseModel(Base):
"""Base model for the SQL stores."""

__abstract__ = True
uuid = sqlalchemy.Column(
sqlalchemy.String, primary_key=True, default=generate_uuid
)

class CollectionStore(BaseModel):
"""Collection store."""

__tablename__ = "collection"
__table_args__ = {"keep_existing": True}

name = sqlalchemy.Column(sqlalchemy.String)
cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType)

embeddings = relationship(
"EmbeddingStore",
back_populates="collection",
passive_deletes=True,
)

@classmethod
def get_by_name(
cls, session: Session, name: str
) -> Optional["CollectionStore"]:
try:
return (
session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined] # noqa: E501
)
except sqlalchemy.exc.ProgrammingError as ex:
if "RelationUnknown" not in str(ex):
raise
return None

@classmethod
def get_or_create(
cls,
session: Session,
name: str,
cmetadata: Optional[dict] = None,
) -> Tuple["CollectionStore", bool]:
"""
Get or create a collection.
Returns [Collection, bool] where the bool is True
if the collection was created.
"""
created = False
collection = cls.get_by_name(session, name)
if collection:
return collection, created

collection = cls(name=name, cmetadata=cmetadata)
session.add(collection)
session.commit()
created = True
return collection, created

collection = cls(name=name, cmetadata=cmetadata)
session.add(collection)
session.commit()
created = True
return collection, created
class EmbeddingStore(BaseModel):
"""Embedding store."""

class EmbeddingStore(BaseModel):
"""Embedding store."""
__tablename__ = "embedding"
__table_args__ = {"keep_existing": True}

__tablename__ = "embedding"
collection_id = sqlalchemy.Column(
sqlalchemy.String,
sqlalchemy.ForeignKey(
f"{CollectionStore.__tablename__}.uuid",
ondelete="CASCADE",
),
)
collection = relationship("CollectionStore", back_populates="embeddings")

collection_id = sqlalchemy.Column(
sqlalchemy.String,
sqlalchemy.ForeignKey(
f"{CollectionStore.__tablename__}.uuid",
ondelete="CASCADE",
),
)
collection = relationship("CollectionStore", back_populates="embeddings")
embedding = sqlalchemy.Column(FloatVector(self.dimensions))
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True)

embedding = sqlalchemy.Column(FloatVector(dimensions))
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True)
# custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)

# custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)

return CollectionStore, EmbeddingStore
self.Base = Base
self.BaseModel = BaseModel
self.CollectionStore = CollectionStore
self.EmbeddingStore = EmbeddingStore
Loading