From c76dc8364a85a60c535973587d15c111f664ed92 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Mon, 20 Nov 2023 21:34:09 +0100 Subject: [PATCH] CrateDB vector: Improve SQLAlchemy model factory From now on, _all_ instances of SQLAlchemy model types will be created at runtime through the `ModelFactory` utility. By using `__table_args__ = {"keep_existing": True}` on the ORM entity definitions, this seems to work well, even with multiple invocations of `CrateDBVectorSearch.from_texts()` using different `collection_name` argument values. While at it, this patch also fixes a few linter errors. --- .../vectorstores/cratedb/__init__.py | 3 +- .../langchain/vectorstores/cratedb/base.py | 68 +++---- .../langchain/vectorstores/cratedb/model.py | 168 ++++++++++-------- .../vectorstores/test_cratedb.py | 37 ++-- 4 files changed, 154 insertions(+), 122 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/cratedb/__init__.py b/libs/langchain/langchain/vectorstores/cratedb/__init__.py index 303a52babeaea..14b02ad126867 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/__init__.py +++ b/libs/langchain/langchain/vectorstores/cratedb/__init__.py @@ -1,6 +1,5 @@ -from .base import BaseModel, CrateDBVectorSearch +from .base import CrateDBVectorSearch __all__ = [ - "BaseModel", "CrateDBVectorSearch", ] diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py index e7c651bea9822..ec3c4c19d70a6 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/base.py +++ b/libs/langchain/langchain/vectorstores/cratedb/base.py @@ -2,7 +2,6 @@ import enum import math -import uuid from typing import ( Any, Callable, @@ -20,11 +19,12 @@ polyfill_refresh_after_dml, refresh_table, ) -from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.orm import sessionmaker from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from 
langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env +from langchain.vectorstores.cratedb.model import ModelFactory from langchain.vectorstores.pgvector import PGVector @@ -38,23 +38,10 @@ class DistanceStrategy(str, enum.Enum): DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN -Base = declarative_base() # type: Any -# Base = declarative_base(metadata=MetaData(schema="langchain")) # type: Any _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" -def generate_uuid() -> str: - return str(uuid.uuid4()) - - -class BaseModel(Base): - """Base model for the SQL stores.""" - - __abstract__ = True - uuid = sqlalchemy.Column(sqlalchemy.String, primary_key=True, default=generate_uuid) - - def _results_to_docs(docs_and_scores: Any) -> List[Document]: """Return docs from docs and scores.""" return [doc for doc, _ in docs_and_scores] @@ -114,30 +101,47 @@ def __post_init__( # Need to defer initialization, because dimension size # can only be figured out at runtime. - self.CollectionStore = None - self.EmbeddingStore = None + self.BaseModel = None + self.CollectionStore = None # type: ignore[assignment] + self.EmbeddingStore = None # type: ignore[assignment] + + def __del__(self) -> None: + """ + Work around premature session close. + + sqlalchemy.orm.exc.DetachedInstanceError: Parent instance is not bound + to a Session; lazy load operation of attribute 'embeddings' cannot proceed. + -- https://docs.sqlalchemy.org/en/20/errors.html#error-bhk3 + + TODO: Review! + """ # noqa: E501 + pass - def _init_models(self, embedding: List[float]): + def _init_models(self, embedding: List[float]) -> None: """ Create SQLAlchemy models at runtime, when not established yet. """ + + # TODO: Use a better way to run this only once. 
if self.CollectionStore is not None and self.EmbeddingStore is not None: return size = len(embedding) self._init_models_with_dimensionality(size=size) - def _init_models_with_dimensionality(self, size: int): - from langchain.vectorstores.cratedb.model import model_factory - - self.CollectionStore, self.EmbeddingStore = model_factory(dimensions=size) + def _init_models_with_dimensionality(self, size: int) -> None: + mf = ModelFactory(dimensions=size) + self.BaseModel, self.CollectionStore, self.EmbeddingStore = ( + mf.BaseModel, # type: ignore[assignment] + mf.CollectionStore, + mf.EmbeddingStore, + ) - def get_collection( - self, session: sqlalchemy.orm.Session - ) -> Optional["CollectionStore"]: + def get_collection(self, session: sqlalchemy.orm.Session) -> Any: if self.CollectionStore is None: raise RuntimeError( - "Collection can't be accessed without specifying dimension size of embedding vectors" + "Collection can't be accessed without specifying " + "dimension size of embedding vectors" ) return self.CollectionStore.get_by_name(session, self.collection_name) @@ -170,15 +174,17 @@ def add_embeddings( def create_tables_if_not_exists(self) -> None: """ - Need to overwrite because `Base` is different from upstream. + Need to overwrite because this `Base` is different from parent's `Base`. """ - Base.metadata.create_all(self._engine) + mf = ModelFactory() + mf.Base.metadata.create_all(self._engine) def drop_tables(self) -> None: """ - Need to overwrite because `Base` is different from upstream. + Need to overwrite because this `Base` is different from parent's `Base`. 
""" - Base.metadata.drop_all(self._engine) + mf = ModelFactory() + mf.Base.metadata.drop_all(self._engine) def delete( self, diff --git a/libs/langchain/langchain/vectorstores/cratedb/model.py b/libs/langchain/langchain/vectorstores/cratedb/model.py index ee42e7269dc9d..b8b14c05010f5 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/model.py +++ b/libs/langchain/langchain/vectorstores/cratedb/model.py @@ -1,84 +1,112 @@ -from functools import lru_cache -from typing import Optional, Tuple +import uuid +from typing import Any, Optional, Tuple import sqlalchemy from crate.client.sqlalchemy.types import ObjectType -from sqlalchemy.orm import Session, relationship +from sqlalchemy.orm import Session, declarative_base, relationship -from langchain.vectorstores.cratedb.base import BaseModel from langchain.vectorstores.cratedb.sqlalchemy_type import FloatVector -@lru_cache -def model_factory(dimensions: int): - class CollectionStore(BaseModel): - """Collection store.""" - - __tablename__ = "collection" - - name = sqlalchemy.Column(sqlalchemy.String) - cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType) - - embeddings = relationship( - "EmbeddingStore", - back_populates="collection", - passive_deletes=True, - ) - - @classmethod - def get_by_name( - cls, session: Session, name: str - ) -> Optional["CollectionStore"]: - try: - return ( - session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined] # noqa: E501 - ) - except sqlalchemy.exc.ProgrammingError as ex: - if "RelationUnknown" not in str(ex): - raise - return None - - @classmethod - def get_or_create( - cls, - session: Session, - name: str, - cmetadata: Optional[dict] = None, - ) -> Tuple["CollectionStore", bool]: - """ - Get or create a collection. - Returns [Collection, bool] where the bool is True if the collection was created. 
- """ - created = False - collection = cls.get_by_name(session, name) - if collection: +def generate_uuid() -> str: + return str(uuid.uuid4()) + + +class ModelFactory: + """Provide SQLAlchemy model objects at runtime.""" + + def __init__(self, dimensions: Optional[int] = None): + # While it does not have any function here, you will still need to supply a + # dummy dimension size value for operations like deleting records. + self.dimensions = dimensions or 1024 + + Base: Any = declarative_base() + + # Optional: Use a custom schema for the langchain tables. + # Base = declarative_base(metadata=MetaData(schema="langchain")) # type: Any + + class BaseModel(Base): + """Base model for the SQL stores.""" + + __abstract__ = True + uuid = sqlalchemy.Column( + sqlalchemy.String, primary_key=True, default=generate_uuid + ) + + class CollectionStore(BaseModel): + """Collection store.""" + + __tablename__ = "collection" + __table_args__ = {"keep_existing": True} + + name = sqlalchemy.Column(sqlalchemy.String) + cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType) + + embeddings = relationship( + "EmbeddingStore", + back_populates="collection", + passive_deletes=True, + ) + + @classmethod + def get_by_name( + cls, session: Session, name: str + ) -> Optional["CollectionStore"]: + try: + return ( + session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined] # noqa: E501 + ) + except sqlalchemy.exc.ProgrammingError as ex: + if "RelationUnknown" not in str(ex): + raise + return None + + @classmethod + def get_or_create( + cls, + session: Session, + name: str, + cmetadata: Optional[dict] = None, + ) -> Tuple["CollectionStore", bool]: + """ + Get or create a collection. + Returns [Collection, bool] where the bool is True + if the collection was created. 
+ """ + created = False + collection = cls.get_by_name(session, name) + if collection: + return collection, created + + collection = cls(name=name, cmetadata=cmetadata) + session.add(collection) + session.commit() + created = True return collection, created - collection = cls(name=name, cmetadata=cmetadata) - session.add(collection) - session.commit() - created = True - return collection, created + class EmbeddingStore(BaseModel): + """Embedding store.""" - class EmbeddingStore(BaseModel): - """Embedding store.""" + __tablename__ = "embedding" + __table_args__ = {"keep_existing": True} - __tablename__ = "embedding" + collection_id = sqlalchemy.Column( + sqlalchemy.String, + sqlalchemy.ForeignKey( + f"{CollectionStore.__tablename__}.uuid", + ondelete="CASCADE", + ), + ) + collection = relationship("CollectionStore", back_populates="embeddings") - collection_id = sqlalchemy.Column( - sqlalchemy.String, - sqlalchemy.ForeignKey( - f"{CollectionStore.__tablename__}.uuid", - ondelete="CASCADE", - ), - ) - collection = relationship("CollectionStore", back_populates="embeddings") + embedding = sqlalchemy.Column(FloatVector(self.dimensions)) + document = sqlalchemy.Column(sqlalchemy.String, nullable=True) + cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True) - embedding = sqlalchemy.Column(FloatVector(dimensions)) - document = sqlalchemy.Column(sqlalchemy.String, nullable=True) - cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True) + # custom_id : any user defined id + custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) - # custom_id : any user defined id - custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) - - return CollectionStore, EmbeddingStore + self.Base = Base + self.BaseModel = BaseModel + self.CollectionStore = CollectionStore + self.EmbeddingStore = EmbeddingStore diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py 
b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py index 8f62919842fc0..8f054fc07a0b3 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py @@ -5,7 +5,7 @@ docker-compose -f cratedb.yml up """ import os -from typing import List, Tuple +from typing import Generator, List, Tuple import pytest import sqlalchemy as sa @@ -13,7 +13,8 @@ from sqlalchemy.orm import Session from langchain.docstore.document import Document -from langchain.vectorstores.cratedb import BaseModel, CrateDBVectorSearch +from langchain.vectorstores.cratedb import CrateDBVectorSearch +from langchain.vectorstores.cratedb.model import ModelFactory from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, FakeEmbeddings, @@ -44,7 +45,7 @@ def engine() -> sa.Engine: @pytest.fixture -def session(engine) -> sa.orm.Session: +def session(engine: sa.Engine) -> Generator[sa.orm.Session, None, None]: with engine.connect() as conn: with Session(conn) as session: yield session @@ -56,7 +57,8 @@ def drop_tables(engine: sa.Engine) -> None: Drop database tables. """ try: - BaseModel.metadata.drop_all(engine, checkfirst=False) + mf = ModelFactory() + mf.BaseModel.metadata.drop_all(engine, checkfirst=False) except Exception as ex: if "RelationUnknown" not in str(ex): raise @@ -69,18 +71,13 @@ def prune_tables(engine: sa.Engine) -> None: """ with engine.connect() as conn: with Session(conn) as session: - from langchain.vectorstores.cratedb.model import model_factory - - # While it does not have any function here, you will still need to supply a - # dummy dimension size value for deleting records from tables. 
- CollectionStore, EmbeddingStore = model_factory(dimensions=1024) - + mf = ModelFactory() try: - session.query(CollectionStore).delete() + session.query(mf.CollectionStore).delete() except ProgrammingError: pass try: - session.query(EmbeddingStore).delete() + session.query(mf.EmbeddingStore).delete() except ProgrammingError: pass @@ -99,13 +96,13 @@ def decode_output( return documents, scores -def ensure_collection(session: sa.orm.Session, name: str): +def ensure_collection(session: sa.orm.Session, name: str) -> None: """ Create a (fake) collection item. """ session.execute( sa.text( - f""" + """ CREATE TABLE IF NOT EXISTS collection ( uuid TEXT, name TEXT, @@ -116,7 +113,7 @@ def ensure_collection(session: sa.orm.Session, name: str): ) session.execute( sa.text( - f""" + """ CREATE TABLE IF NOT EXISTS embedding ( uuid TEXT, collection_id TEXT, @@ -131,7 +128,8 @@ def ensure_collection(session: sa.orm.Session, name: str): try: session.execute( sa.text( - f"INSERT INTO collection (uuid, name, cmetadata) VALUES ('uuid-{name}', '{name}', {{}});" + f"INSERT INTO collection (uuid, name, cmetadata) " + f"VALUES ('uuid-{name}', '{name}', {{}});" ) ) session.execute(sa.text("REFRESH TABLE collection")) @@ -325,7 +323,7 @@ def test_cratedb_collection_with_metadata() -> None: def test_cratedb_collection_no_embedding_dimension() -> None: """Test end to end collection construction""" cratedb_vector = CrateDBVectorSearch( - embedding_function=None, + embedding_function=None, # type: ignore[arg-type] connection_string=CONNECTION_STRING, pre_delete_collection=True, ) @@ -333,11 +331,12 @@ def test_cratedb_collection_no_embedding_dimension() -> None: with pytest.raises(RuntimeError) as ex: cratedb_vector.get_collection(session) assert ex.match( - "Collection can't be accessed without specifying dimension size of embedding vectors" + "Collection can't be accessed without specifying " + "dimension size of embedding vectors" ) -def test_cratedb_collection_read_only(session) -> None: 
+def test_cratedb_collection_read_only(session: Session) -> None: """ Test using a collection, without adding any embeddings upfront.