-
Notifications
You must be signed in to change notification settings - Fork 61
/
Copy pathconftest.py
119 lines (97 loc) · 3.93 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Fixtures for the tests."""
import os
import socket
import tempfile
from collections.abc import Generator
from pathlib import Path
import pytest
from sqlalchemy import create_engine, text
from raglite import RAGLiteConfig, insert_document
POSTGRES_URL = "postgresql+pg8000://raglite_user:raglite_password@postgres:5432/postgres"
def is_postgres_running() -> bool:
"""Check if PostgreSQL is running."""
try:
with socket.create_connection(("postgres", 5432), timeout=1):
return True
except OSError:
return False
def is_openai_available() -> bool:
"""Check if an OpenAI API key is set."""
return bool(os.environ.get("OPENAI_API_KEY"))
def pytest_sessionstart(session: pytest.Session) -> None:
"""Reset the PostgreSQL and SQLite databases."""
if is_postgres_running():
engine = create_engine(POSTGRES_URL, isolation_level="AUTOCOMMIT")
with engine.connect() as conn:
for variant in ["local", "remote"]:
conn.execute(text(f"DROP DATABASE IF EXISTS raglite_test_{variant}"))
conn.execute(text(f"CREATE DATABASE raglite_test_{variant}"))
@pytest.fixture(scope="session")
def sqlite_url() -> Generator[str, None, None]:
"""Create a temporary SQLite database file and return the database URL."""
with tempfile.TemporaryDirectory() as temp_dir:
db_file = Path(temp_dir) / "raglite_test.sqlite"
yield f"sqlite:///{db_file}"
@pytest.fixture(
scope="session",
params=[
pytest.param("sqlite", id="sqlite"),
pytest.param(
POSTGRES_URL,
id="postgres",
marks=pytest.mark.skipif(not is_postgres_running(), reason="PostgreSQL is not running"),
),
],
)
def database(request: pytest.FixtureRequest) -> str:
"""Get a database URL to test RAGLite with."""
db_url: str = (
request.getfixturevalue("sqlite_url") if request.param == "sqlite" else request.param
)
return db_url
@pytest.fixture(
scope="session",
params=[
pytest.param(
(
"llama-cpp-python/bartowski/Llama-3.2-3B-Instruct-GGUF/*Q4_K_M.gguf@4096",
"llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@1024", # More context degrades performance.
),
id="llama_3.2_3B-bge_m3",
),
pytest.param(
("gpt-4o-mini", "text-embedding-3-small"),
id="gpt_4o_mini-text_embedding_3_small",
marks=pytest.mark.skipif(not is_openai_available(), reason="OpenAI API key is not set"),
),
],
)
def llm_embedder(request: pytest.FixtureRequest) -> str:
"""Get an LLM and embedder pair to test RAGLite with."""
llm_embedder: str = request.param
return llm_embedder
@pytest.fixture(scope="session")
def llm(llm_embedder: tuple[str, str]) -> str:
"""Get an LLM to test RAGLite with."""
llm, _ = llm_embedder
return llm
@pytest.fixture(scope="session")
def embedder(llm_embedder: tuple[str, str]) -> str:
"""Get an embedder to test RAGLite with."""
_, embedder = llm_embedder
return embedder
@pytest.fixture(scope="session")
def raglite_test_config(database: str, llm: str, embedder: str) -> RAGLiteConfig:
"""Create a lightweight in-memory config for testing SQLite and PostgreSQL."""
# Select the database based on the embedder.
variant = "local" if embedder.startswith("llama-cpp-python") else "remote"
if "postgres" in database:
database = database.replace("/postgres", f"/raglite_test_{variant}")
elif "sqlite" in database:
database = database.replace(".sqlite", f"_{variant}.sqlite")
# Create a RAGLite config for the given database and embedder.
db_config = RAGLiteConfig(db_url=database, llm=llm, embedder=embedder)
# Insert a document and update the index.
doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
insert_document(doc_path, config=db_config)
return db_config