-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(sql): sql magic, support postgresql/mysql/sqlite
- Loading branch information
1 parent
affdd0e
commit 3ad3b2f
Showing
12 changed files
with
370 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# python generated files | ||
__pycache__/ | ||
*.py[oc] | ||
build/ | ||
dist/ | ||
wheels/ | ||
*.egg-info | ||
|
||
# venv | ||
.venv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
3.10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# libro-sql | ||
|
||
# 使用 | ||
|
||
## 加载 | ||
|
||
```shell | ||
%load_ext libro_sql | ||
``` | ||
|
||
# 设置 | ||
|
||
```python | ||
from libro_sql.database import db | ||
db.config({ | ||
'db_type': '', | ||
'username': '', | ||
'password': '', | ||
'host': '', | ||
'port': 5432, | ||
'database': '' | ||
}) | ||
``` | ||
|
||
# 执行 | ||
|
||
```python | ||
%%sql | ||
{"result_variable":"a", "sql_script":"select 1"} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
[project] | ||
name = "libro-sql" | ||
version = "0.1.2" | ||
description = "libro flow" | ||
authors = [ | ||
{ name = "brokun", email = "[email protected]" }, | ||
{ name = "sunshinesmilelk", email = "[email protected]" }, | ||
] | ||
dependencies = [ | ||
"ipython>=7.34.0", | ||
"sqlalchemy>=2.0.34", | ||
"pandas>=2.2.2", | ||
"pydantic>=2.9.1", | ||
"psycopg2-binary>=2.9.9", | ||
"pymysql>=1.1.1", | ||
] | ||
dev-dependencies = [] | ||
readme = "README.md" | ||
requires-python = ">= 3.10" | ||
|
||
[build-system] | ||
requires = ["hatchling"] | ||
build-backend = "hatchling.build" | ||
|
||
[tool.rye] | ||
managed = true | ||
|
||
[tool.hatch.metadata] | ||
allow-direct-references = true | ||
|
||
[tool.hatch.build.targets.wheel] | ||
packages = ["src/libro_sql"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from ._version import __version__ | ||
|
||
from .extensions import ( | ||
load_ipython_extension, | ||
unload_ipython_extension, | ||
_load_jupyter_server_extension, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# -*- coding: utf-8 -*- | ||
__version__ = "0.1.2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
|
||
from typing import Optional | ||
from pydantic import BaseModel | ||
from sqlalchemy import create_engine, text | ||
from sqlalchemy.exc import SQLAlchemyError | ||
import pandas as pd | ||
|
||
|
||
class DatabaseConfig(BaseModel): | ||
db_type: str | ||
username: str | ||
password: str | ||
host: str | ||
port: int | ||
database: str | ||
|
||
|
||
class Database: | ||
config: DatabaseConfig | ||
|
||
def __init__(self, config: DatabaseConfig): | ||
self.config = config | ||
self.engine = self._create_engine() | ||
|
||
def _create_engine(self): | ||
"""Create the SQLAlchemy engine based on the database type.""" | ||
config = self.config | ||
try: | ||
if config.db_type == 'postgresql': | ||
engine = create_engine( | ||
f'postgresql+psycopg2://{config.username}:{config.password}@{config.host}:{config.port}/{config.database}') | ||
elif config.db_type == 'mysql': | ||
engine = create_engine( | ||
f'mysql+pymysql://{config.username}:{config.password}@{config.host}:{config.port}/{config.database}') | ||
elif config.db_type == 'sqlite': | ||
engine = create_engine(f'sqlite:///{config.database}') | ||
else: | ||
raise ValueError( | ||
f"Unsupported database type: {config.db_type}") | ||
return engine | ||
except Exception as e: | ||
print(f"Error creating engine: {e}") | ||
raise | ||
|
||
def execute(self, query): | ||
"""Execute a SQL query or non-query and return the result. | ||
If the query is a SELECT statement, return the result as a DataFrame. | ||
For other statements (INSERT, UPDATE, DELETE), execute the statement and return the number of affected rows. | ||
""" | ||
with self.engine.connect() as connection: | ||
try: | ||
result = connection.execute(text(query)) | ||
if result.returns_rows: | ||
# Fetch all rows and construct DataFrame with column names | ||
rows = result.fetchall() | ||
if rows: | ||
# Debug: Print fetched rows | ||
df = pd.DataFrame(rows, columns=result.keys()) | ||
else: | ||
df = pd.DataFrame() # Return empty DataFrame if no rows | ||
return df | ||
else: | ||
if result.rowcount is not None: | ||
connection.commit() | ||
return result.rowcount | ||
else: | ||
return result | ||
except SQLAlchemyError as e: | ||
print(f"Error executing query: {e}") | ||
raise | ||
|
||
|
||
class DatabaseManager(): | ||
db: Optional[Database] = None | ||
|
||
def config(self, c: dict): | ||
config = DatabaseConfig.model_validate(c) | ||
self.db = Database(config) | ||
|
||
def execute(self, query): | ||
"""Execute a SQL query or non-query and return the result. | ||
If the query is a SELECT statement, return the result as a DataFrame. | ||
For other statements (INSERT, UPDATE, DELETE), execute the statement and return the number of affected rows. | ||
""" | ||
if self.db is not None: | ||
return self.db.execute(query) | ||
else: | ||
raise Exception( | ||
'Can not execute sql before database config set') | ||
|
||
|
||
db = DatabaseManager() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from IPython.core.interactiveshell import InteractiveShell | ||
|
||
|
||
def store_exception(shell: InteractiveShell, etype: type, evalue, tb, tb_offset=None): | ||
# A structured traceback (a list of strings) or None | ||
|
||
if issubclass(etype, SyntaxError): | ||
# Disable ANSI color strings | ||
shell.SyntaxTB.color_toggle() | ||
# Don't display a stacktrace because a syntax error has no stacktrace | ||
stb = shell.SyntaxTB.structured_traceback(etype, evalue, []) | ||
stb_text = shell.SyntaxTB.stb2text(stb) | ||
# Re-enable ANSI color strings | ||
shell.SyntaxTB.color_toggle() | ||
else: | ||
# Disable ANSI color strings | ||
shell.InteractiveTB.color_toggle() | ||
stb = shell.InteractiveTB.structured_traceback( | ||
etype, evalue, tb, tb_offset=tb_offset | ||
) | ||
stb_text = shell.InteractiveTB.stb2text(stb) | ||
# Re-enable ANSI color strings | ||
shell.InteractiveTB.color_toggle() | ||
|
||
etraceback = shell.showtraceback() | ||
|
||
styled_exception = str(stb_text) | ||
|
||
prompt_number = shell.execution_count | ||
err = shell.user_ns.get("Err", {}) | ||
err[prompt_number] = styled_exception | ||
shell.user_ns["Err"] = err | ||
|
||
# Return | ||
return etraceback |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from IPython.core.interactiveshell import InteractiveShell | ||
from .exception import store_exception | ||
from .sql_magic import SQLMagic | ||
|
||
|
||
def load_ipython_extension(ipython: InteractiveShell): | ||
ipython.register_magics(SQLMagic) | ||
ipython.set_custom_exc((BaseException,), store_exception) | ||
|
||
|
||
def unload_ipython_extension(ipython: InteractiveShell): | ||
ipython.set_custom_exc((BaseException,), ipython.CustomTB) | ||
|
||
|
||
def _load_jupyter_server_extension(ipython): | ||
"""Load the Jupyter server extension. | ||
Parameters | ||
---------- | ||
ipython: :class:`jupyter_client.ioloop.IOLoopKernelManager` | ||
Jupyter kernel manager instance. | ||
""" | ||
load_ipython_extension(ipython) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
import base64 | ||
import json | ||
from IPython.core.magic import Magics, magics_class, line_cell_magic | ||
from .database import db | ||
|
||
|
||
def is_ipython() -> bool: | ||
""" | ||
Check if interface is launching from iPython | ||
:return is_ipython (bool): True or False | ||
""" | ||
is_ipython = False | ||
try: # Check if running interactively using ipython. | ||
from IPython import get_ipython | ||
|
||
if get_ipython() is not None: | ||
is_ipython = True | ||
except (ImportError, NameError): | ||
pass | ||
return is_ipython | ||
|
||
|
||
def preprocessing_line(line, local_ns): | ||
try: | ||
user_input = str(base64.decodebytes(line.encode()), "utf-8") | ||
# 将JSON字符串解析成Python对象 | ||
json_obj = json.loads(user_input) | ||
content = json_obj.get("sql_script") | ||
# 替换变量 | ||
if content: | ||
for key, value in local_ns.items(): | ||
if key and not key.startswith("_"): | ||
content = content.replace("{{" + key + "}}", str(value)) | ||
json_obj["sql_script"] = content | ||
return json_obj | ||
except Exception as e: | ||
raise Exception("preprocess error", e) | ||
|
||
|
||
def preprocessing_cell(cell, local_ns): | ||
try: | ||
# 将JSON字符串解析成Python对象 | ||
json_obj = json.loads(cell) | ||
content = json_obj.get("sql_script") | ||
# 替换变量 | ||
if content: | ||
for key, value in local_ns.items(): | ||
if key and not key.startswith("_"): | ||
content = content.replace("{{" + key + "}}", str(value)) | ||
json_obj["sql_script"] = content | ||
return json_obj | ||
except Exception as e: | ||
raise Exception("preprocess error", e) | ||
|
||
|
||
@magics_class | ||
class SQLMagic(Magics): | ||
""" | ||
%%prompt | ||
{"result_variable":"custom_variable_name","sql_script":"SELECT 1"} | ||
""" | ||
|
||
def __init__(self, shell=None): | ||
super(SQLMagic, self).__init__(shell) | ||
|
||
@line_cell_magic | ||
def sql(self, line="", cell=None): | ||
local_ns = self.shell.user_ns # type: ignore | ||
if cell is None: | ||
args = preprocessing_line(line, local_ns) | ||
else: | ||
args = preprocessing_cell(cell, local_ns) | ||
|
||
result_variable: str = args.get("result_variable") | ||
sql_script: str = args.get("sql_script") | ||
|
||
if sql_script is None or sql_script == "": | ||
raise Exception("Invalid sql script!") | ||
|
||
res = db.execute(sql_script) | ||
|
||
# Set variable | ||
try: | ||
if result_variable is None or result_variable == "": | ||
return | ||
if not result_variable.isidentifier(): | ||
raise Exception( | ||
'Invalid variable name "{}".'.format(result_variable) | ||
) | ||
else: | ||
local_ns[result_variable] = res | ||
except Exception as e: | ||
raise Exception("set variable error", e) | ||
|
||
if is_ipython(): | ||
from IPython.display import display | ||
display(res) |
Oops, something went wrong.