forked from quantalogic/quantalogic
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path09-sql-query.py
executable file
·233 lines (190 loc) · 8.21 KB
/
09-sql-query.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
#!/usr/bin/env -S uv run
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "quantalogic",
# ]
# ///
import argparse
import os
from typing import Any

import loguru
from rich.console import Console
from rich.markdown import Markdown
from rich.markup import escape
from rich.panel import Panel
from rich.prompt import Confirm, Prompt
from rich.syntax import Syntax

from quantalogic import Agent
from quantalogic.console_print_events import console_print_events
from quantalogic.console_print_token import console_print_token
from quantalogic.tools import GenerateDatabaseReportTool, InputQuestionTool, SQLQueryTool
from quantalogic.tools.utils import create_sample_database
# Parse command-line arguments.
# RawDescriptionHelpFormatter keeps the epilog's line breaks; argparse's default
# formatter would re-wrap the examples and the model list into one paragraph.
parser = argparse.ArgumentParser(
    description="Interactive SQL query interface powered by AI",
    formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog="""
Examples:
  python 09-sql-query.py --model deepseek/deepseek-chat
  python 09-sql-query.py --help

Available models:
  - deepseek/deepseek-chat (default)
  - openai/gpt-4o-mini
  - anthropic/claude-3.5-sonnet
  - openrouter/deepseek/deepseek-chat
  - openrouter/mistralai/mistral-large-2411
""",
)
parser.add_argument(
    "--model",
    type=str,
    default="deepseek/deepseek-chat",
    help="Model name to use (default: deepseek/deepseek-chat) or any of the following: openai/gpt-4o-mini, anthropic/claude-3.5-sonnet, openrouter/deepseek/deepseek-chat, openrouter/mistralai/mistral-large-2411",
)
args = parser.parse_args()

# `default=` above already covers running with no arguments; this guard only
# catches an explicitly empty value such as `--model ""`.
if not args.model:
    args.model = "deepseek/deepseek-chat"

# deepseek/deepseek-chat is the default for its balance of cost and capability;
# any model listed in the epilog above can be selected with --model instead.
MODEL_NAME = args.model
# Fail fast when the provider behind MODEL_NAME has no API key configured.
# The provider is identified by the model-name prefix. The prefixes are
# mutually exclusive, so at most one entry can match. Models from any other
# provider are not checked here.
_PROVIDER_KEY_VARS = {
    "deepseek": "DEEPSEEK_API_KEY",
    "openai": "OPENAI_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
    "openrouter": "OPENROUTER_API_KEY",
}
for _prefix, _key_var in _PROVIDER_KEY_VARS.items():
    if MODEL_NAME.startswith(_prefix):
        if not os.environ.get(_key_var):
            raise ValueError(f"{_key_var} environment variable is not set")
        break
# Shared rich console used by every prompt, spinner and panel in this script.
console = Console()

# Resolve the database connection string. A non-empty DB_CONNECTION_STRING
# environment variable wins, which keeps credentials out of the source. Otherwise
# the user is asked interactively, with a local SQLite demo file as the
# suggested default.
db_conn = os.environ.get("DB_CONNECTION_STRING")
if not db_conn:
    db_conn = Prompt.ask(
        "[bold]Enter database connection string[/bold]",
        default="sqlite:///sample.db",
        console=console,
    )
def get_database_report():
    """Run ``GenerateDatabaseReportTool`` against ``db_conn`` and return its output.

    A fresh tool instance is built on every call, and whatever its ``execute()``
    returns is handed back unchanged.
    """
    report_tool = GenerateDatabaseReportTool(connection_string=db_conn)
    report = report_tool.execute()
    return report
def create_agent(connection_string: str) -> Agent:
    """Build the quantalogic ``Agent`` that turns questions into SQL.

    Args:
        connection_string: Database URL passed to ``SQLQueryTool``.

    Returns:
        An ``Agent`` configured with ``MODEL_NAME``, an SQL query tool and an
        input-question tool.
    """
    sql_tool = SQLQueryTool(connection_string=connection_string)
    question_tool = InputQuestionTool()
    return Agent(
        model_name=MODEL_NAME,
        tools=[sql_tool, question_tool],
        specific_expertise="SQL Query Assistant able to generate SQL queries based on natural language questions",
        task_to_solve="Generate SQL queries based on user questions about the database",
    )


# Module-level agent shared by the event handlers and the query loop.
agent = create_agent(db_conn)
# Lifecycle events forwarded to quantalogic's `console_print_events` handler,
# so the user sees thinking phases, tool calls, completion and the
# iteration-limit safety stop as they happen.
_LIFECYCLE_EVENTS = [
    "task_complete",  # final task state
    "task_think_start",  # agent starts reasoning
    "task_think_end",  # agent finishes reasoning
    "tool_execution_start",  # a tool call begins
    "tool_execution_end",  # a tool call returns
    "error_max_iterations_reached",  # safety limit hit
]
agent.event_emitter.on(_LIFECYCLE_EVENTS, console_print_events)
# Visual feedback system using spinner
# Global state ensures only one spinner runs at a time
current_spinner = None # Tracks active spinner instance
def start_spinner(event: str, data: Any | None = None) -> None:
"""Start spinner to indicate processing state.
Uses global state to prevent multiple concurrent spinners.
"""
global current_spinner
current_spinner = console.status("[bold green]Analyzing query...[/bold green]", spinner="dots")
current_spinner.start()
def stop_spinner(event: str, data: Any | None = None) -> None:
"""Cleanly stop spinner and release resources.
Prevents memory leaks from orphaned spinners.
"""
global current_spinner
if current_spinner:
current_spinner.stop()
current_spinner = None # Clear reference to allow garbage collection
# Spinner lifecycle hooks, registered in this order: show the spinner when a
# task starts; hide it on the first streamed chunk (registered ahead of the
# token printer) and again when the task ends.
loguru.logger.info("Registering event listeners")
for _event_name, _handler in (
    ("task_solve_start", start_spinner),
    ("stream_chunk", stop_spinner),
    ("stream_chunk", console_print_token),
    ("task_solve_end", stop_spinner),
):
    agent.event_emitter.on(_event_name, _handler)
def _extract_sole_sql_block(text: str) -> str | None:
    """Return the code inside *text* if *text* is exactly one ```sql fence, else None."""
    stripped = text.strip()
    header, newline, rest = stripped.partition("\n")
    if not newline or header.strip().lower() != "```sql":
        return None
    # The only remaining fence must be the closing one at the very end.
    if not rest.endswith("```") or rest.count("```") != 1:
        return None
    return rest[:-3].strip("\n")


def format_markdown(result: str) -> Panel:
    """Render an agent or report response inside a styled rich ``Panel``.

    A response that consists solely of one fenced ```sql block is shown as
    syntax-highlighted SQL, without the fence markers, under a "Generated SQL"
    title. Anything else is rendered as Markdown, including Markdown that only
    *contains* SQL fences next to prose or tables. Rich's Markdown still
    highlights fenced code blocks using ``code_theme``.

    Args:
        result: Text returned by the agent or by the database report tool.

    Returns:
        A ``Panel`` ready for ``console.print``.
    """
    sql_code = _extract_sole_sql_block(result)
    if sql_code is not None:
        return Panel.fit(
            Syntax(sql_code, "sql", theme="monokai", line_numbers=False),
            title="Generated SQL",
            border_style="blue",
        )
    md = Markdown(result, code_theme="dracula", inline_code_theme="dracula", justify="left")
    return Panel.fit(
        md,
        title="[bold]Query Results[/bold]",
        border_style="bright_cyan",
        padding=(1, 2),
        subtitle="📊 Database Results",
    )
def query_loop():
    """Run the interactive question -> SQL -> answer loop until the user quits.

    The loop works as follows:

    - Print a banner and generate the database report once.
    - Repeatedly prompt for a natural-language question and ask the agent to
      answer it, with that report embedded in the task.
    - End the session when the user types ``exit``/``quit``/``q`` (any case),
      declines a follow-up confirmation, or presses Ctrl-C / Ctrl-D.
    - Show any other exception raised while handling a question in an error
      panel, and let the user choose whether to continue.
    """
    console.print(
        Panel.fit(
            "[bold reverse] 💽 SQL QUERY INTERFACE [/bold reverse]",
            border_style="bright_magenta",
            subtitle="Type 'exit' to quit",
        )
    )

    # The report is generated once and embedded in every task so the agent
    # stays within the database's actual structure.
    console.print("Generating database report...")
    database_report = get_database_report()
    console.print(format_markdown(database_report), width=90)

    while True:
        try:
            question = Prompt.ask("\n[bold cyan]❓ Your question[/bold cyan]")
            if question.lower() in ("exit", "quit", "q"):
                break
            task_description = f"""
As an expert database analyst, perform these steps:
1. Analyze the question: "{question}"
2. Generate appropriate SQL query
3. Execute the SQL query and present the results
The database context is as follows, strictly respect it:
{database_report}
"""
            result = agent.solve_task(task_description, streaming=True)
            console.print(format_markdown(result), width=90)
            if not Confirm.ask("[bold]Submit another query?[/bold]", default=True):
                break
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C / Ctrl-D: KeyboardInterrupt is not an Exception subclass.
            # Leave the loop cleanly instead of dumping a traceback.
            stop_spinner("interrupted")
            break
        except Exception as e:
            # A failure mid-task can skip the "task_solve_end" event, which
            # would leave the spinner running over the error panel.
            stop_spinner("error")
            # escape() stops bracketed text in the message (e.g. "[parameters: ...]")
            # from being parsed as rich markup tags.
            console.print(
                Panel.fit(
                    f"[red bold]ERROR:[/red bold] {escape(str(e))}",
                    border_style="red",
                    title="🚨 Processing Error",
                )
            )
            try:
                keep_going = Confirm.ask("[bold]Try another question?[/bold]", default=True)
            except (KeyboardInterrupt, EOFError):
                keep_going = False
            if not keep_going:
                break

    console.print(
        Panel.fit(
            "[bold green]Session terminated[/bold green]",
            border_style="bright_green",
            subtitle="Thank you for using the SQL interface!",
        )
    )
if __name__ == "__main__":
    # Create the demo SQLite file that the default connection string
    # (sqlite:///sample.db) points at, then start the interactive session.
    sample_db_file = "sample.db"
    create_sample_database(sample_db_file)
    query_loop()