Skip to content

Commit

Permalink
CrateDB vector: Improve SQLAlchemy model factory
Browse files Browse the repository at this point in the history
From now on, _all_ instances of SQLAlchemy model types will be created
at runtime through the `ModelFactory` utility.

By using `__table_args__ = {"keep_existing": True}` on the ORM entity
definitions, this seems to work well, even with multiple invocations
of `CrateDBVectorSearch.from_texts()` using different `collection_name`
argument values.

While at it, this patch also fixes a few linter errors.
  • Loading branch information
amotl committed Nov 20, 2023
1 parent 93e8970 commit c76dc83
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 122 deletions.
3 changes: 1 addition & 2 deletions libs/langchain/langchain/vectorstores/cratedb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .base import BaseModel, CrateDBVectorSearch
from .base import CrateDBVectorSearch

__all__ = [
"BaseModel",
"CrateDBVectorSearch",
]
68 changes: 37 additions & 31 deletions libs/langchain/langchain/vectorstores/cratedb/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import enum
import math
import uuid
from typing import (
Any,
Callable,
Expand All @@ -20,11 +19,12 @@
polyfill_refresh_after_dml,
refresh_table,
)
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.orm import sessionmaker

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.schema.embeddings import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.cratedb.model import ModelFactory
from langchain.vectorstores.pgvector import PGVector


Expand All @@ -38,23 +38,10 @@ class DistanceStrategy(str, enum.Enum):

DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN

Base = declarative_base() # type: Any
# Base = declarative_base(metadata=MetaData(schema="langchain")) # type: Any

_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"


def generate_uuid() -> str:
return str(uuid.uuid4())


class BaseModel(Base):
"""Base model for the SQL stores."""

__abstract__ = True
uuid = sqlalchemy.Column(sqlalchemy.String, primary_key=True, default=generate_uuid)


def _results_to_docs(docs_and_scores: Any) -> List[Document]:
    """Return only the documents from an iterable of (document, score) pairs.

    Every item must unpack into exactly two values; the second one (the
    score) is discarded.
    """
    return [document for document, _score in docs_and_scores]
Expand Down Expand Up @@ -114,30 +101,47 @@ def __post_init__(

# Need to defer initialization, because dimension size
# can only be figured out at runtime.
self.CollectionStore = None
self.EmbeddingStore = None
self.BaseModel = None
self.CollectionStore = None # type: ignore[assignment]
self.EmbeddingStore = None # type: ignore[assignment]

def __del__(self) -> None:
    """
    Do nothing on finalization, on purpose.

    This no-op exists to work around a premature session close, which
    surfaced as:

    sqlalchemy.orm.exc.DetachedInstanceError: Parent instance <CollectionStore at 0x1212ca3d0> is not bound
    to a Session; lazy load operation of attribute 'embeddings' cannot proceed.
    -- https://docs.sqlalchemy.org/en/20/errors.html#error-bhk3

    NOTE(review): presumably this shadows a finalizer inherited from the
    parent class -- confirm against the parent implementation.
    TODO: Review!
    """  # noqa: E501
    return None

def _init_models(self, embedding: List[float]):
def _init_models(self, embedding: List[float]) -> None:
"""
Create SQLAlchemy models at runtime, when not established yet.
"""

# TODO: Use a better way to run this only once.
if self.CollectionStore is not None and self.EmbeddingStore is not None:
return

size = len(embedding)
self._init_models_with_dimensionality(size=size)

def _init_models_with_dimensionality(self, size: int):
from langchain.vectorstores.cratedb.model import model_factory

self.CollectionStore, self.EmbeddingStore = model_factory(dimensions=size)
def _init_models_with_dimensionality(self, size: int) -> None:
mf = ModelFactory(dimensions=size)
self.BaseModel, self.CollectionStore, self.EmbeddingStore = (
mf.BaseModel, # type: ignore[assignment]
mf.CollectionStore,
mf.EmbeddingStore,
)

def get_collection(
self, session: sqlalchemy.orm.Session
) -> Optional["CollectionStore"]:
def get_collection(self, session: sqlalchemy.orm.Session) -> Any:
if self.CollectionStore is None:
raise RuntimeError(
"Collection can't be accessed without specifying dimension size of embedding vectors"
"Collection can't be accessed without specifying "
"dimension size of embedding vectors"
)
return self.CollectionStore.get_by_name(session, self.collection_name)

Expand Down Expand Up @@ -170,15 +174,17 @@ def add_embeddings(

def create_tables_if_not_exists(self) -> None:
"""
Need to overwrite because `Base` is different from upstream.
Need to overwrite because this `Base` is different from parent's `Base`.
"""
Base.metadata.create_all(self._engine)
mf = ModelFactory()
mf.Base.metadata.create_all(self._engine)

def drop_tables(self) -> None:
"""
Need to overwrite because `Base` is different from upstream.
Need to overwrite because this `Base` is different from parent's `Base`.
"""
Base.metadata.drop_all(self._engine)
mf = ModelFactory()
mf.Base.metadata.drop_all(self._engine)

def delete(
self,
Expand Down
168 changes: 98 additions & 70 deletions libs/langchain/langchain/vectorstores/cratedb/model.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,112 @@
from functools import lru_cache
from typing import Optional, Tuple
import uuid
from typing import Any, Optional, Tuple

import sqlalchemy
from crate.client.sqlalchemy.types import ObjectType
from sqlalchemy.orm import Session, relationship
from sqlalchemy.orm import Session, declarative_base, relationship

from langchain.vectorstores.cratedb.base import BaseModel
from langchain.vectorstores.cratedb.sqlalchemy_type import FloatVector


@lru_cache
def model_factory(dimensions: int):
class CollectionStore(BaseModel):
"""Collection store."""

__tablename__ = "collection"

name = sqlalchemy.Column(sqlalchemy.String)
cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType)

embeddings = relationship(
"EmbeddingStore",
back_populates="collection",
passive_deletes=True,
)

@classmethod
def get_by_name(
cls, session: Session, name: str
) -> Optional["CollectionStore"]:
try:
return (
session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined] # noqa: E501
)
except sqlalchemy.exc.ProgrammingError as ex:
if "RelationUnknown" not in str(ex):
raise
return None

@classmethod
def get_or_create(
cls,
session: Session,
name: str,
cmetadata: Optional[dict] = None,
) -> Tuple["CollectionStore", bool]:
"""
Get or create a collection.
Returns [Collection, bool] where the bool is True if the collection was created.
"""
created = False
collection = cls.get_by_name(session, name)
if collection:
def generate_uuid() -> str:
    """Return a new random (version 4) UUID in its canonical string form."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)


class ModelFactory:
"""Provide SQLAlchemy model objects at runtime."""

def __init__(self, dimensions: Optional[int] = None):
# While it does not have any function here, you will still need to supply a
# dummy dimension size value for operations like deleting records.
self.dimensions = dimensions or 1024

Base: Any = declarative_base()

# Optional: Use a custom schema for the langchain tables.
# Base = declarative_base(metadata=MetaData(schema="langchain")) # type: Any

class BaseModel(Base):
"""Base model for the SQL stores."""

__abstract__ = True
uuid = sqlalchemy.Column(
sqlalchemy.String, primary_key=True, default=generate_uuid
)

class CollectionStore(BaseModel):
"""Collection store."""

__tablename__ = "collection"
__table_args__ = {"keep_existing": True}

name = sqlalchemy.Column(sqlalchemy.String)
cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType)

embeddings = relationship(
"EmbeddingStore",
back_populates="collection",
passive_deletes=True,
)

@classmethod
def get_by_name(
cls, session: Session, name: str
) -> Optional["CollectionStore"]:
try:
return (
session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined] # noqa: E501
)
except sqlalchemy.exc.ProgrammingError as ex:
if "RelationUnknown" not in str(ex):
raise
return None

@classmethod
def get_or_create(
cls,
session: Session,
name: str,
cmetadata: Optional[dict] = None,
) -> Tuple["CollectionStore", bool]:
"""
Get or create a collection.
Returns [Collection, bool] where the bool is True
if the collection was created.
"""
created = False
collection = cls.get_by_name(session, name)
if collection:
return collection, created

collection = cls(name=name, cmetadata=cmetadata)
session.add(collection)
session.commit()
created = True
return collection, created

collection = cls(name=name, cmetadata=cmetadata)
session.add(collection)
session.commit()
created = True
return collection, created
class EmbeddingStore(BaseModel):
"""Embedding store."""

class EmbeddingStore(BaseModel):
"""Embedding store."""
__tablename__ = "embedding"
__table_args__ = {"keep_existing": True}

__tablename__ = "embedding"
collection_id = sqlalchemy.Column(
sqlalchemy.String,
sqlalchemy.ForeignKey(
f"{CollectionStore.__tablename__}.uuid",
ondelete="CASCADE",
),
)
collection = relationship("CollectionStore", back_populates="embeddings")

collection_id = sqlalchemy.Column(
sqlalchemy.String,
sqlalchemy.ForeignKey(
f"{CollectionStore.__tablename__}.uuid",
ondelete="CASCADE",
),
)
collection = relationship("CollectionStore", back_populates="embeddings")
embedding = sqlalchemy.Column(FloatVector(self.dimensions))
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True)

embedding = sqlalchemy.Column(FloatVector(dimensions))
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True)
# custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)

# custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)

return CollectionStore, EmbeddingStore
self.Base = Base
self.BaseModel = BaseModel
self.CollectionStore = CollectionStore
self.EmbeddingStore = EmbeddingStore
Loading

0 comments on commit c76dc83

Please sign in to comment.