Skip to content

Commit

Permalink
优化SQL解析功能
Browse files Browse the repository at this point in the history
  • Loading branch information
welliamcao committed May 5, 2020
1 parent bc24aa9 commit d49b782
Show file tree
Hide file tree
Showing 7 changed files with 434 additions and 62 deletions.
61 changes: 41 additions & 20 deletions apps/databases/mysql/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from OpsManage.settings import config
from databases.models import *
from utils import base
from utils.sqlparse import sql_parse
# from threading import Thread
import threading

Expand Down Expand Up @@ -106,34 +107,54 @@ def _check_user_db_tables(self, db):
pass
return []

def __check_sql_parse(self, sql, allow_sql, dbname):
try:
sql = sql.split(' ')
sqlCmd, sqlCmds = sql[0].upper().strip(),(sql[0]+'_'+sql[1]).upper().replace(";","").strip()
except Exception as ex:
return "解析SQL失败: {ex}".format(ex=ex)
def _extract_keyword_from_sql(self, sql):
return sql_parse.extract_sql_keyword(sql)

def _extract_table_name_from_sql(self ,sql):
schema = []
tables = []
for ds in sql_parse.extract_tables(sql):

if ds.schema and ds.schema not in schema:
schema.append(ds.schema)

if ds.name and ds.name not in tables:
tables.append(ds.name)

if len(schema) > 0:
return "不支持跨数据库类型SQL"

return tables

def __check_sql_parse(self, sql, allow_sql):
#查询用户是不是有授权表
grant_tables = self._check_user_db_tables(self.scope['url_route']['kwargs']['id'])

#提取SQL中的表名
extract_table = base.extract_table_name_from_sql(" ".join(sql))

if extract_table:
if grant_tables:
for tb in extract_table:
if tb.find('.') >= 0:
db,tb = tb.split('.')[0],tb.split('.')[1]
if db != dbname:return "不支持跨库查询"
if tb not in grant_tables:return "操作的表未授权"
extract_table = self._extract_table_name_from_sql(sql)

if isinstance(extract_table, list) and grant_tables:

for tb in extract_table:
if tb not in grant_tables:
return "操作的表未授权"

elif isinstance(extract_table, str):
return extract_table

else:#如果提交的SQL里面没有包含授权的表,就检查SQL类型是否授权
#查询用户授权的SQL类型
grant_sql = self._check_user_db_sql(self.scope['url_route']['kwargs']['id'])

if sqlCmd.upper() in grant_sql or sqlCmds in grant_sql:return True

if sqlCmd not in allow_sql: return 'SQL类型不支持'

sql_type, _first_token , keywords = self._extract_keyword_from_sql(sql)

if len(keywords) > 1:
if keywords[0] + '_' + keywords[1] in grant_sql:
return True
# print(_first_token, keywords, grant_sql, allow_sql)

if _first_token in allow_sql: return True

return "SQL未授权, 联系管理员授权"

return True
Expand Down Expand Up @@ -199,7 +220,7 @@ def receive(self, text_data=None, bytes_data=None):
def _check_sql(self, text_data):
if len(self.sql) >= 2:
if text_data == '\r' and (self.sql[-1]==';' or self.sql[-2:]=='\G'):
sql_parse = self.__check_sql_parse(self.sql, self.dml_sql + self.ddl_sql + self.dql_sql, self.db.get("db_name"))
sql_parse = self.__check_sql_parse(self.sql, self.dml_sql + self.ddl_sql + self.dql_sql)
try:
if isinstance(sql_parse, str):
self.status = False
Expand Down
65 changes: 43 additions & 22 deletions dao/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import time, json
from jinja2 import Template
from databases.models import *
from utils import base
from utils.logger import logger
from .assets import AssetsBase
from asset.models import *
from datetime import datetime
from databases.service.mysql_base import MySQLBase
from utils import base
from utils.sqlparse import sql_parse
from utils.mysql.binlog2sql import Binlog2sql
from utils.mysql.const import SQL_PERMISSIONS,SQL_DICT_HTML
from apps.tasks.celery_sql import record_exec_sql, export_table, parse_binlog
Expand Down Expand Up @@ -236,7 +237,7 @@ def update_user_server_db(self, request, db_server, user):

class DBManage(AssetsBase):
dml_sql = ["INSERT","UPDATE","DELETE"]
dql_sql = ["SELECT","SHOW","DESC","EXPLAIN"]
dql_sql = ["SELECT","DESC","EXPLAIN"]
ddl_sql = ["CREATE","DROP","ALTER","TRUNCATE"]

def __init__(self):
Expand Down Expand Up @@ -285,35 +286,56 @@ def __check_user_db_sql(self,request):

return []

def __check_sql_parse(self, request, allow_sql, dbname, sql):
try:
sql = sql.split(' ')
sqlCmd, sqlCmds = sql[0].upper(),(sql[0]+'_'+sql[1]).upper().replace(";","")
except Exception as ex:
logger.error(msg="解析SQL失败: {ex}".format(ex=ex))
return '解析SQL失败'
def __extract_keyword_from_sql(self, sql):
return sql_parse.extract_sql_keyword(sql)

def __extract_table_name_from_sql(self ,sql):
schema = []
tables = []
for ds in sql_parse.extract_tables(sql):

if ds.schema and ds.schema not in schema:
schema.append(ds.schema)

if ds.name and ds.name not in tables:
tables.append(ds.name)

if len(schema) > 0:
return "不支持跨数据库类型SQL"

return tables

def __check_sql_parse(self, request, allow_sql, sql, read_only=True):
#查询用户是不是有授权表
grant_tables = self.__check_user_db_tables(request)

#提取SQL中的表名
extract_table = base.extract_table_name_from_sql(" ".join(sql))
extract_table = self.__extract_table_name_from_sql(sql)

if isinstance(extract_table, list) and grant_tables:

if extract_table:
if grant_tables:
for tb in extract_table:
if tb.find('.') >= 0:
db,tb = tb.split('.')[0],tb.split('.')[1]
if db != dbname:return "不支持跨库查询"
if tb not in grant_tables:return "操作的表未授权"
for tb in extract_table:
if tb not in grant_tables:
return "操作的表未授权"

elif isinstance(extract_table, str):
return extract_table

else:#如果提交的SQL里面没有包含授权的表,就检查SQL类型是否授权
#查询用户授权的SQL类型
grant_sql = self.__check_user_db_sql(request)

if sqlCmd.upper() in grant_sql or sqlCmds in grant_sql:return True
sql_type, _first_token , keywords = self.__extract_keyword_from_sql(sql)

if len(keywords) > 1:
if keywords[0] + '_' + keywords[1] in grant_sql:
return True
# print(_first_token, keywords, grant_sql, allow_sql)

if sqlCmd not in allow_sql: return 'SQL类型不支持'
if read_only and _first_token == 'SELECT' and 'INTO' in keywords:return "当前操作,不允许写入"

if _first_token in allow_sql: return True

return "SQL未授权, 联系管理员授权"

return True
Expand Down Expand Up @@ -351,8 +373,7 @@ def exec_sql(self, request):
for sql in sql_list:
stime = int(time.time())
sql = sql.strip('\n') + ';'
sql_parse = self.__check_sql_parse(request, sql=sql, allow_sql=self.dml_sql + self.ddl_sql + self.dql_sql,
dbname=dbServer.get('db_name'))
sql_parse = self.__check_sql_parse(request, sql=sql, allow_sql=self.dml_sql + self.ddl_sql + self.dql_sql, read_only=False)

if not isinstance(sql_parse, str):
result = self.__get_db_server(dbServer).execute(sql, 1000)
Expand Down Expand Up @@ -383,7 +404,7 @@ def query_sql(self, request):
for sql in sql_list:
stime = int(time.time())
sql = sql.strip('\n') + ';'
sql_parse = self.__check_sql_parse(request, sql=sql, allow_sql=self.dql_sql, dbname=dbServer.get('db_name'))
sql_parse = self.__check_sql_parse(request, sql=sql, allow_sql=self.dql_sql, read_only=True)
if not isinstance(sql_parse, str):
result = self.__get_db_server(dbServer).queryMany(sql, 1000)
if not isinstance(result, str):
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ redis==2.10.6
requests==2.20.0
six==1.11.0
smmap2==2.0.5
sqlparse==0.3.1
tablib==0.12.1
Twisted==18.9.0
txaio==18.8.1
Expand Down
20 changes: 0 additions & 20 deletions utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,6 @@
from functools import wraps
import ply.lex as lex, re

def extract_table_name_from_sql(sql_str):
q = re.sub(r"/\*[^*]*\*+(?:[^*/][^*]*\*+)*/", "", sql_str)

lines = [line for line in q.splitlines() if not re.match("^\s*(--|#)", line)]

q = " ".join([re.split("--|#", line)[0] for line in lines])

tokens = re.split(r"[\s)(;]+", q)

result = []
get_next = False
for token in tokens:
if get_next:
if token.lower() not in ["", "select"]:
result.append(token)
get_next = False
get_next = token.lower() in ["from", "join","into","table","update","desc"]

return result

def method_decorator_adaptor(adapt_to, *decorator_args, **decorator_kwargs):
def decorator_outer(func):
@wraps(func)
Expand Down
2 changes: 2 additions & 0 deletions utils/sqlparse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#This function is borrowed from pip install clickhouse_cli
from .parse import sql_parse
Loading

0 comments on commit d49b782

Please sign in to comment.