Skip to content

Commit

Permalink
feat(sql): sql magic, support postgresql/mysql/sqlite
Browse files Browse the repository at this point in the history
  • Loading branch information
BroKun authored and sunshinesmilelk committed Sep 13, 2024
1 parent affdd0e commit 3ad3b2f
Show file tree
Hide file tree
Showing 12 changed files with 370 additions and 6 deletions.
10 changes: 10 additions & 0 deletions libro-sql/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# python generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# venv
.venv
1 change: 1 addition & 0 deletions libro-sql/.python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.10
30 changes: 30 additions & 0 deletions libro-sql/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# libro-sql

# 使用

## 加载

```shell
%load_ext libro_sql
```

# 设置

```python
from libro_sql.database import db
db.config({
'db_type': '',
'username': '',
'password': '',
'host': '',
'port': 5432,
'database': ''
})
```

# 执行

```python
%%sql
{"result_variable":"a", "sql_script":"select 1"}
```
32 changes: 32 additions & 0 deletions libro-sql/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
[project]
name = "libro-sql"
version = "0.1.2"
description = "libro flow"
authors = [
{ name = "brokun", email = "[email protected]" },
{ name = "sunshinesmilelk", email = "[email protected]" },
]
dependencies = [
"ipython>=7.34.0",
"sqlalchemy>=2.0.34",
"pandas>=2.2.2",
"pydantic>=2.9.1",
"psycopg2-binary>=2.9.9",
"pymysql>=1.1.1",
]
dev-dependencies = []
readme = "README.md"
requires-python = ">= 3.10"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.rye]
managed = true

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["src/libro_sql"]
7 changes: 7 additions & 0 deletions libro-sql/src/libro_sql/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from ._version import __version__

from .extensions import (
load_ipython_extension,
unload_ipython_extension,
_load_jupyter_server_extension,
)
2 changes: 2 additions & 0 deletions libro-sql/src/libro_sql/_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
__version__ = "0.1.2"
94 changes: 94 additions & 0 deletions libro-sql/src/libro_sql/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@

from typing import Optional
from pydantic import BaseModel
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError
import pandas as pd


class DatabaseConfig(BaseModel):
db_type: str
username: str
password: str
host: str
port: int
database: str


class Database:
config: DatabaseConfig

def __init__(self, config: DatabaseConfig):
self.config = config
self.engine = self._create_engine()

def _create_engine(self):
"""Create the SQLAlchemy engine based on the database type."""
config = self.config
try:
if config.db_type == 'postgresql':
engine = create_engine(
f'postgresql+psycopg2://{config.username}:{config.password}@{config.host}:{config.port}/{config.database}')
elif config.db_type == 'mysql':
engine = create_engine(
f'mysql+pymysql://{config.username}:{config.password}@{config.host}:{config.port}/{config.database}')
elif config.db_type == 'sqlite':
engine = create_engine(f'sqlite:///{config.database}')
else:
raise ValueError(
f"Unsupported database type: {config.db_type}")
return engine
except Exception as e:
print(f"Error creating engine: {e}")
raise

def execute(self, query):
"""Execute a SQL query or non-query and return the result.
If the query is a SELECT statement, return the result as a DataFrame.
For other statements (INSERT, UPDATE, DELETE), execute the statement and return the number of affected rows.
"""
with self.engine.connect() as connection:
try:
result = connection.execute(text(query))
if result.returns_rows:
# Fetch all rows and construct DataFrame with column names
rows = result.fetchall()
if rows:
# Debug: Print fetched rows
df = pd.DataFrame(rows, columns=result.keys())
else:
df = pd.DataFrame() # Return empty DataFrame if no rows
return df
else:
if result.rowcount is not None:
connection.commit()
return result.rowcount
else:
return result
except SQLAlchemyError as e:
print(f"Error executing query: {e}")
raise


class DatabaseManager():
db: Optional[Database] = None

def config(self, c: dict):
config = DatabaseConfig.model_validate(c)
self.db = Database(config)

def execute(self, query):
"""Execute a SQL query or non-query and return the result.
If the query is a SELECT statement, return the result as a DataFrame.
For other statements (INSERT, UPDATE, DELETE), execute the statement and return the number of affected rows.
"""
if self.db is not None:
return self.db.execute(query)
else:
raise Exception(
'Can not execute sql before database config set')


db = DatabaseManager()
35 changes: 35 additions & 0 deletions libro-sql/src/libro_sql/exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from IPython.core.interactiveshell import InteractiveShell


def store_exception(shell: InteractiveShell, etype: type, evalue, tb, tb_offset=None):
# A structured traceback (a list of strings) or None

if issubclass(etype, SyntaxError):
# Disable ANSI color strings
shell.SyntaxTB.color_toggle()
# Don't display a stacktrace because a syntax error has no stacktrace
stb = shell.SyntaxTB.structured_traceback(etype, evalue, [])
stb_text = shell.SyntaxTB.stb2text(stb)
# Re-enable ANSI color strings
shell.SyntaxTB.color_toggle()
else:
# Disable ANSI color strings
shell.InteractiveTB.color_toggle()
stb = shell.InteractiveTB.structured_traceback(
etype, evalue, tb, tb_offset=tb_offset
)
stb_text = shell.InteractiveTB.stb2text(stb)
# Re-enable ANSI color strings
shell.InteractiveTB.color_toggle()

etraceback = shell.showtraceback()

styled_exception = str(stb_text)

prompt_number = shell.execution_count
err = shell.user_ns.get("Err", {})
err[prompt_number] = styled_exception
shell.user_ns["Err"] = err

# Return
return etraceback
22 changes: 22 additions & 0 deletions libro-sql/src/libro_sql/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from IPython.core.interactiveshell import InteractiveShell
from .exception import store_exception
from .sql_magic import SQLMagic


def load_ipython_extension(ipython: InteractiveShell):
ipython.register_magics(SQLMagic)
ipython.set_custom_exc((BaseException,), store_exception)


def unload_ipython_extension(ipython: InteractiveShell):
ipython.set_custom_exc((BaseException,), ipython.CustomTB)


def _load_jupyter_server_extension(ipython):
"""Load the Jupyter server extension.
Parameters
----------
ipython: :class:`jupyter_client.ioloop.IOLoopKernelManager`
Jupyter kernel manager instance.
"""
load_ipython_extension(ipython)
99 changes: 99 additions & 0 deletions libro-sql/src/libro_sql/sql_magic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-

import base64
import json
from IPython.core.magic import Magics, magics_class, line_cell_magic
from .database import db


def is_ipython() -> bool:
"""
Check if interface is launching from iPython
:return is_ipython (bool): True or False
"""
is_ipython = False
try: # Check if running interactively using ipython.
from IPython import get_ipython

if get_ipython() is not None:
is_ipython = True
except (ImportError, NameError):
pass
return is_ipython


def preprocessing_line(line, local_ns):
try:
user_input = str(base64.decodebytes(line.encode()), "utf-8")
# 将JSON字符串解析成Python对象
json_obj = json.loads(user_input)
content = json_obj.get("sql_script")
# 替换变量
if content:
for key, value in local_ns.items():
if key and not key.startswith("_"):
content = content.replace("{{" + key + "}}", str(value))
json_obj["sql_script"] = content
return json_obj
except Exception as e:
raise Exception("preprocess error", e)


def preprocessing_cell(cell, local_ns):
try:
# 将JSON字符串解析成Python对象
json_obj = json.loads(cell)
content = json_obj.get("sql_script")
# 替换变量
if content:
for key, value in local_ns.items():
if key and not key.startswith("_"):
content = content.replace("{{" + key + "}}", str(value))
json_obj["sql_script"] = content
return json_obj
except Exception as e:
raise Exception("preprocess error", e)


@magics_class
class SQLMagic(Magics):
"""
%%prompt
{"result_variable":"custom_variable_name","sql_script":"SELECT 1"}
"""

def __init__(self, shell=None):
super(SQLMagic, self).__init__(shell)

@line_cell_magic
def sql(self, line="", cell=None):
local_ns = self.shell.user_ns # type: ignore
if cell is None:
args = preprocessing_line(line, local_ns)
else:
args = preprocessing_cell(cell, local_ns)

result_variable: str = args.get("result_variable")
sql_script: str = args.get("sql_script")

if sql_script is None or sql_script == "":
raise Exception("Invalid sql script!")

res = db.execute(sql_script)

# Set variable
try:
if result_variable is None or result_variable == "":
return
if not result_variable.isidentifier():
raise Exception(
'Invalid variable name "{}".'.format(result_variable)
)
else:
local_ns[result_variable] = res
except Exception as e:
raise Exception("set variable error", e)

if is_ipython():
from IPython.display import display
display(res)
Loading

0 comments on commit 3ad3b2f

Please sign in to comment.