From d49b782e640e6f6a073c4df439f7930ed1d527f8 Mon Sep 17 00:00:00 2001 From: "Welliam.Cao" <303350019@qq.com> Date: Tue, 5 May 2020 23:00:18 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96SQL=E8=A7=A3=E6=9E=90?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/databases/mysql/consumers.py | 61 +++++--- dao/mysql.py | 65 +++++--- requirements.txt | 1 + utils/base.py | 20 --- utils/sqlparse/__init__.py | 2 + utils/sqlparse/parse.py | 246 ++++++++++++++++++++++++++++++ utils/sqlparse/patch.py | 101 ++++++++++++ 7 files changed, 434 insertions(+), 62 deletions(-) create mode 100644 utils/sqlparse/__init__.py create mode 100644 utils/sqlparse/parse.py create mode 100644 utils/sqlparse/patch.py diff --git a/apps/databases/mysql/consumers.py b/apps/databases/mysql/consumers.py index 1b51004c..880278df 100644 --- a/apps/databases/mysql/consumers.py +++ b/apps/databases/mysql/consumers.py @@ -6,6 +6,7 @@ from OpsManage.settings import config from databases.models import * from utils import base +from utils.sqlparse import sql_parse # from threading import Thread import threading @@ -106,34 +107,54 @@ def _check_user_db_tables(self, db): pass return [] - def __check_sql_parse(self, sql, allow_sql, dbname): - try: - sql = sql.split(' ') - sqlCmd, sqlCmds = sql[0].upper().strip(),(sql[0]+'_'+sql[1]).upper().replace(";","").strip() - except Exception as ex: - return "解析SQL失败: {ex}".format(ex=ex) + def _extract_keyword_from_sql(self, sql): + return sql_parse.extract_sql_keyword(sql) + + def _extract_table_name_from_sql(self ,sql): + schema = [] + tables = [] + for ds in sql_parse.extract_tables(sql): + + if ds.schema and ds.schema not in schema: + schema.append(ds.schema) + + if ds.name and ds.name not in tables: + tables.append(ds.name) + + if len(schema) > 0: + return "不支持跨数据库类型SQL" + return tables + + def __check_sql_parse(self, sql, allow_sql): #查询用户是不是有授权表 grant_tables = self._check_user_db_tables(self.scope['url_route']['kwargs']['id']) #提取SQL中的表名 - extract_table = base.extract_table_name_from_sql(" ".join(sql)) - - if extract_table: - if grant_tables: - for tb in extract_table: - if tb.find('.') >= 0: - db,tb = tb.split('.')[0],tb.split('.')[1] - if db != dbname:return "不支持跨库查询" - if tb not in grant_tables:return "操作的表未授权" + extract_table = self._extract_table_name_from_sql(sql) + + if isinstance(extract_table, list) and grant_tables: + + for tb in extract_table: + if tb not in grant_tables: + return "操作的表未授权" + + elif isinstance(extract_table, str): + return extract_table + else:#如果提交的SQL里面没有包含授权的表,就检查SQL类型是否授权 #查询用户授权的SQL类型 grant_sql = self._check_user_db_sql(self.scope['url_route']['kwargs']['id']) - - if sqlCmd.upper() in grant_sql or sqlCmds in grant_sql:return True - - if sqlCmd not in allow_sql: return 'SQL类型不支持' + sql_type, _first_token , keywords = self._extract_keyword_from_sql(sql) + + if len(keywords) > 1: + if keywords[0] + '_' + keywords[1] in grant_sql: + return True +# print(_first_token, keywords, grant_sql, allow_sql) + + if _first_token in allow_sql: return True + return "SQL未授权, 联系管理员授权" return True @@ -199,7 +220,7 @@ def receive(self, text_data=None, bytes_data=None): def _check_sql(self, text_data): if len(self.sql) >= 2: if text_data == '\r' and (self.sql[-1]==';' or self.sql[-2:]=='\G'): - sql_parse = self.__check_sql_parse(self.sql, self.dml_sql + self.ddl_sql + self.dql_sql, self.db.get("db_name")) + sql_parse = self.__check_sql_parse(self.sql, self.dml_sql + self.ddl_sql + self.dql_sql) try: if isinstance(sql_parse, str): self.status = False diff --git a/dao/mysql.py b/dao/mysql.py index 6cf1c247..f9e36ac8 100644 --- a/dao/mysql.py +++ b/dao/mysql.py @@ -4,12 +4,13 @@ import time, json from jinja2 import Template from databases.models import * +from utils import base from utils.logger import logger from .assets import AssetsBase from asset.models import * from datetime import datetime from databases.service.mysql_base import MySQLBase -from utils import base +from utils.sqlparse import sql_parse from utils.mysql.binlog2sql import Binlog2sql from utils.mysql.const import SQL_PERMISSIONS,SQL_DICT_HTML from apps.tasks.celery_sql import record_exec_sql, export_table, parse_binlog @@ -236,7 +237,7 @@ def update_user_server_db(self, request, db_server, user): class DBManage(AssetsBase): dml_sql = ["INSERT","UPDATE","DELETE"] - dql_sql = ["SELECT","SHOW","DESC","EXPLAIN"] + dql_sql = ["SELECT","DESC","EXPLAIN"] ddl_sql = ["CREATE","DROP","ALTER","TRUNCATE"] def __init__(self): @@ -285,35 +286,56 @@ def __check_user_db_sql(self,request): return [] - def __check_sql_parse(self, request, allow_sql, dbname, sql): - try: - sql = sql.split(' ') - sqlCmd, sqlCmds = sql[0].upper(),(sql[0]+'_'+sql[1]).upper().replace(";","") - except Exception as ex: - logger.error(msg="解析SQL失败: {ex}".format(ex=ex)) - return '解析SQL失败' + def __extract_keyword_from_sql(self, sql): + return sql_parse.extract_sql_keyword(sql) + + def __extract_table_name_from_sql(self ,sql): + schema = [] + tables = [] + for ds in sql_parse.extract_tables(sql): + + if ds.schema and ds.schema not in schema: + schema.append(ds.schema) + + if ds.name and ds.name not in tables: + tables.append(ds.name) + + if len(schema) > 0: + return "不支持跨数据库类型SQL" + return tables + + def __check_sql_parse(self, request, allow_sql, sql, read_only=True): #查询用户是不是有授权表 grant_tables = self.__check_user_db_tables(request) #提取SQL中的表名 - extract_table = base.extract_table_name_from_sql(" ".join(sql)) + extract_table = self.__extract_table_name_from_sql(sql) + + if isinstance(extract_table, list) and grant_tables: - if extract_table: - if grant_tables: - for tb in extract_table: - if tb.find('.') >= 0: - db,tb = tb.split('.')[0],tb.split('.')[1] - if db != dbname:return "不支持跨库查询" - if tb not in grant_tables:return "操作的表未授权" + for tb in extract_table: + if tb not in grant_tables: + return "操作的表未授权" + + elif isinstance(extract_table, str): + return extract_table + else:#如果提交的SQL里面没有包含授权的表,就检查SQL类型是否授权 #查询用户授权的SQL类型 grant_sql = self.__check_user_db_sql(request) - if sqlCmd.upper() in grant_sql or sqlCmds in grant_sql:return True + sql_type, _first_token , keywords = self.__extract_keyword_from_sql(sql) + + if len(keywords) > 1: + if keywords[0] + '_' + keywords[1] in grant_sql: + return True +# print(_first_token, keywords, grant_sql, allow_sql) - if sqlCmd not in allow_sql: return 'SQL类型不支持' + if read_only and _first_token == 'SELECT' and 'INTO' in keywords:return "当前操作,不允许写入" + if _first_token in allow_sql: return True + return "SQL未授权, 联系管理员授权" return True @@ -351,8 +373,7 @@ def exec_sql(self, request): for sql in sql_list: stime = int(time.time()) sql = sql.strip('\n') + ';' - sql_parse = self.__check_sql_parse(request, sql=sql, allow_sql=self.dml_sql + self.ddl_sql + self.dql_sql, - dbname=dbServer.get('db_name')) + sql_parse = self.__check_sql_parse(request, sql=sql, allow_sql=self.dml_sql + self.ddl_sql + self.dql_sql, read_only=False) if not isinstance(sql_parse, str): result = self.__get_db_server(dbServer).execute(sql, 1000) @@ -383,7 +404,7 @@ def query_sql(self, request): for sql in sql_list: stime = int(time.time()) sql = sql.strip('\n') + ';' - sql_parse = self.__check_sql_parse(request, sql=sql, allow_sql=self.dql_sql, dbname=dbServer.get('db_name')) + sql_parse = self.__check_sql_parse(request, sql=sql, allow_sql=self.dql_sql, read_only=True) if not isinstance(sql_parse, str): result = self.__get_db_server(dbServer).queryMany(sql, 1000) if not isinstance(result, str): diff --git a/requirements.txt b/requirements.txt index c43fbf72..d8ccbc5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -74,6 +74,7 @@ redis==2.10.6 requests==2.20.0 six==1.11.0 smmap2==2.0.5 +sqlparse==0.3.1 tablib==0.12.1 Twisted==18.9.0 txaio==18.8.1 diff --git a/utils/base.py b/utils/base.py index e7571039..eec47439 100644 --- a/utils/base.py +++ b/utils/base.py @@ -14,26 +14,6 @@ from functools import wraps import ply.lex as lex, re -def extract_table_name_from_sql(sql_str): - q = re.sub(r"/\*[^*]*\*+(?:[^*/][^*]*\*+)*/", "", sql_str) - - lines = [line for line in q.splitlines() if not re.match("^\s*(--|#)", line)] - - q = " ".join([re.split("--|#", line)[0] for line in lines]) - - tokens = re.split(r"[\s)(;]+", q) - - result = [] - get_next = False - for token in tokens: - if get_next: - if token.lower() not in ["", "select"]: - result.append(token) - get_next = False - get_next = token.lower() in ["from", "join","into","table","update","desc"] - - return result - def method_decorator_adaptor(adapt_to, *decorator_args, **decorator_kwargs): def decorator_outer(func): @wraps(func) diff --git a/utils/sqlparse/__init__.py b/utils/sqlparse/__init__.py new file mode 100644 index 00000000..8652599b --- /dev/null +++ b/utils/sqlparse/__init__.py @@ -0,0 +1,2 @@ +#This function is borrowed from pip install clickhouse_cli +from .parse import sql_parse \ No newline at end of file diff --git a/utils/sqlparse/parse.py b/utils/sqlparse/parse.py new file mode 100644 index 00000000..38e33d09 --- /dev/null +++ b/utils/sqlparse/parse.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python +# _#_ coding:utf-8 _*_ +import sqlparse +from collections import namedtuple +from sqlparse.sql import IdentifierList, Identifier, Function +from sqlparse.tokens import Keyword, DML, DDL, Punctuation +from .patch import KEYWORDS + +TableReference = namedtuple('TableReference', ['schema', 'name', 'alias', 'is_function']) +TableReference.ref = property( + lambda self: self.alias or ( + self.name if self.name.islower() or self.name[0] == '"' else '"' + self.name + '"') +) + +sqlparse.keywords.KEYWORDS = KEYWORDS + +class SQLParse: + + def is_subselect(self, parsed): + if not parsed.is_group: + return False + for item in parsed.tokens: + if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT', 'CREATE'): + return True + return False + + def _identifiers(self, tok): + if isinstance(tok, IdentifierList): + for t in tok.get_identifiers(): + if isinstance(t, Identifier): + yield t + elif isinstance(tok, Identifier): + yield tok + + def extract_column_names(self, sql): + parsed = sqlparse.parse(sql)[0] + if not parsed: + return () + + idx, tok = parsed.token_next_by(t=DML) + tok_val = tok and tok.value.lower() + + if tok_val in ('insert', 'update', 'delete'): + idx, tok = parsed.token_next_by(idx, (Keyword, 'returning')) + elif not tok_val == 'select': + return () + + idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True) + return tuple(t.get_name() for t in self._identifiers(tok)) + + def _identifier_is_function(self, identifier): + return any(isinstance(t, Function) for t in identifier.tokens) + + + def extract_from_part(self, parsed, stop_at_punctuation=True): + tbl_prefix_seen = False + for item in parsed.tokens: + if tbl_prefix_seen: + if self.is_subselect(item): + for x in self.extract_from_part(item, stop_at_punctuation): + yield x + elif stop_at_punctuation and item.ttype is Punctuation: + raise StopIteration + + elif item.ttype is Keyword and ( + not item.value.upper() == 'FROM') and ( + not item.value.upper().endswith('JOIN')): + tbl_prefix_seen = False + else: + yield item + elif item.ttype is Keyword or item.ttype is Keyword.DML: + item_val = item.value.upper() + if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or + item_val.endswith('JOIN')): + tbl_prefix_seen = True + + elif isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + if (identifier.ttype is Keyword and + identifier.value.upper() == 'FROM'): + tbl_prefix_seen = True + break + + + def extract_table_identifiers(self, token_stream, allow_functions=True): + + def parse_identifier(item): + name = item.get_real_name() + schema_name = item.get_parent_name() + alias = item.get_alias() + if not name: + schema_name = None + name = item.get_name() + alias = alias or name + schema_quoted = schema_name and item.value[0] == '"' + if schema_name and not schema_quoted: + schema_name = schema_name.lower() + quote_count = item.value.count('"') + name_quoted = quote_count > 2 or (quote_count and not schema_quoted) + alias_quoted = alias and item.value[-1] == '"' + if alias_quoted or name_quoted and not alias and name.islower(): + alias = '"' + (alias or name) + '"' + if name and not name_quoted and not name.islower(): + if not alias: + alias = name + name = name.lower() + return schema_name, name, alias + + for item in token_stream: + if isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + try: + schema_name = identifier.get_parent_name() + real_name = identifier.get_real_name() + is_function = (allow_functions and + self._identifier_is_function(identifier)) + except AttributeError: + continue + if real_name: + yield TableReference(schema_name, real_name, + identifier.get_alias(), is_function) + elif isinstance(item, Identifier): + schema_name, real_name, alias = parse_identifier(item) + is_function = allow_functions and self._identifier_is_function(item) + + yield TableReference(schema_name, real_name, alias, is_function) + elif isinstance(item, Function): + schema_name, real_name, alias = parse_identifier(item) + yield TableReference(None, real_name, alias, allow_functions) + + + def extract_tables(self,sql): + parsed = sqlparse.parse(sql) + if not parsed: + return () + + insert_stmt = parsed[0].token_first().value.lower() == 'insert' + stream = self.extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) + + identifiers = self.extract_table_identifiers(stream, allow_functions=not insert_stmt) + return tuple(i for i in identifiers if i.name) + + def extract_sql_keyword(self, sql): + sql_type = 'unknown' + keywords = [] + parsed = sqlparse.parse(sql) + if not parsed: + return () + _first_token = sqlparse.sql.Statement(parsed[0].tokens).token_first().value.upper() + for item in parsed[0].tokens: + if item.ttype is DML: + sql_type = 'dml' + elif item.ttype is DDL: + sql_type = 'ddl' + key = item.value.upper() + if (item.ttype is Keyword or \ + item.ttype is Keyword.DDL or item.ttype is Keyword.DML) and key not in keywords: + keywords.append(item.value.upper()) + return sql_type, _first_token , keywords + +sql_parse = SQLParse() + +if __name__ == "__main__": + sql = """ + SELECT + a.*, f.ORG_NAME DEPT_NAME, + IFNULL(d.CONT_COUNT, 0) SIGN_CONT_COUNT, + IFNULL(d.TOTAL_PRICE, 0) SIGN_CONT_MONEY, + IFNULL(c.CONT_COUNT, 0) SIGN_ARRI_CONT_COUNT, + IFNULL(c.TOTAL_PRICE, 0) SIGN_ARRI_CONT_MONEY, + IFNULL(b.CONT_COUNT, 0) TOTAL_ARRI_CONT_COUNT, + IFNULL(b.TOTAL_PRICE, 0) TOTAL_ARRI_MONEY, + 0 PUBLISH_TOTAL_COUNT, + 0 PROJECT_COUNT, + 0 COMMON_COUNT, + 0 STOCK_COUNT, + 0 MERGER_COUNT, + 0 INDUSTRY_COUNT, + 0 BRAND_COUNT + FROM + ( + SELECT + u.USER_ID, + u.REAL_NAME, + u.ORG_PARENT_ID, + o.ORG_NAME, + u.ORG_ID + FROM + SE_USER u + INNER JOIN SE_ORGANIZ o ON u.ORG_PARENT_ID = o.ORG_ID + WHERE + u.`STATUS` = 1 + AND u.`LEVEL` IN (1, 2, 3) + AND o.PARENT_ID <> 0 + ) a + LEFT JOIN SE_ORGANIZ f ON a.ORG_ID = f.ORG_ID + LEFT JOIN ( + SELECT + CUST_MGR_ID, + COUNT(CONT_ID) CONT_COUNT, + SUM(TOTAL_PRICE) TOTAL_PRICE + FROM + SE_CONTRACT + WHERE + DATE_FORMAT(CREATE_TIME, '%Y-%m-%d') = '2012-06-08' + GROUP BY + CUST_MGR_ID + ) d ON a.USER_ID = d.CUST_MGR_ID + LEFT JOIN ( + SELECT + CUST_MGR_ID, + COUNT(CONT_ID) CONT_COUNT, + SUM(TOTAL_PRICE) TOTAL_PRICE + FROM + SE_CONTRACT + WHERE + (STATUS = 6 OR STATUS = 10) + AND DATE_FORMAT(CREATE_TIME, '%Y-%m-%d') = '2012-06-08' + GROUP BY + CUST_MGR_ID + ) c ON a.USER_ID = c.CUST_MGR_ID + LEFT JOIN ( + SELECT + c.CUST_MGR_ID, + COUNT(c.CONT_ID) CONT_COUNT, + SUM(c.TOTAL_PRICE) TOTAL_PRICE + FROM + SE_CONTRACT c + INNER JOIN SE_CONT_AUDIT a ON c.CONT_ID = a.CONT_ID + WHERE + (c. STATUS = 6 OR c. STATUS = 10) + AND a.IS_PASS = 1 + AND DATE_FORMAT(a.AUDIT_TIME, '%Y-%m-%d') = '2012-06-08' + GROUP BY + c.CUST_MGR_ID + ) b ON a.USER_ID = b.CUST_MGR_ID + ORDER BY + a.ORG_PARENT_ID, + a.USER_ID + """ + + print(sql_parse.extract_tables(sql)) + print(sql_parse.extract_column_names(sql)) + print(sql_parse.extract_sql_keyword("use abc")) + print(sql_parse.extract_sql_keyword(sql)) + diff --git a/utils/sqlparse/patch.py b/utils/sqlparse/patch.py new file mode 100644 index 00000000..63fd4ef1 --- /dev/null +++ b/utils/sqlparse/patch.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python +# _#_ coding:utf-8 _*_ +from sqlparse import tokens + +KEYWORDS = { + 'ADD': tokens.Keyword, + 'AFTER': tokens.Keyword, + 'ALIAS': tokens.Keyword, + 'ALL': tokens.Keyword, + 'ALTER': tokens.Keyword.DDL, + 'AND': tokens.Keyword, + 'ANY': tokens.Keyword, + 'ARRAY': tokens.Keyword, + 'AS': tokens.Keyword, + 'ASC': tokens.Keyword.Order, + 'ATTACH': tokens.Keyword.DDL, + 'BETWEEN': tokens.Keyword, + 'BY': tokens.Keyword, + 'CASE': tokens.Keyword, + 'CAST': tokens.Keyword, + 'CHECK': tokens.Keyword.DDL, + 'COLUMN': tokens.Keyword, + 'COPY': tokens.Keyword, + 'CREATE': tokens.Keyword.DDL, + 'DATABASE': tokens.Keyword, + 'STATUS': tokens.Keyword, + 'VARIABLES': tokens.Keyword, + 'DATABASES': tokens.Keyword, + 'INDEX': tokens.Keyword, + # 'DEFAULT': tokens.Keyword, + 'DELETE': tokens.Keyword.DML, + 'DESC': tokens.Keyword, + 'DESCRIBE': tokens.Keyword.DDL, + 'DETACH': tokens.Keyword.DDL, + 'DISTINCT': tokens.Keyword, + 'DROP': tokens.Keyword.DDL, + 'ELSE': tokens.Keyword, + 'END': tokens.Keyword, + 'ENGINE': tokens.Keyword, + 'EXISTS': tokens.Keyword, + 'FALSE': tokens.Keyword, + 'FETCH': tokens.Keyword, + 'FINAL': tokens.Keyword, + 'FIRST': tokens.Keyword, + 'FORMAT': tokens.Keyword, + 'FREEZE': tokens.Keyword, + 'FROM': tokens.Keyword, + 'FULL': tokens.Keyword, + 'GLOBAL': tokens.Keyword, + 'GROUP': tokens.Keyword, + 'HAVING': tokens.Keyword, + 'IF': tokens.Keyword, + 'IN': tokens.Keyword, + 'INNER': tokens.Keyword, + 'INSERT': tokens.Keyword.DML, + 'INTO': tokens.Keyword, + 'IS': tokens.Keyword, + 'JOIN': tokens.Keyword, + 'KEY': tokens.Keyword, + 'LEFT': tokens.Keyword, + 'LIKE': tokens.Keyword, + 'LIMIT': tokens.Keyword, + 'MATERIALIZED': tokens.Keyword, + 'MODIFY': tokens.Keyword, + 'NAME': tokens.Keyword, + 'NOT': tokens.Keyword, + 'OF': tokens.Keyword, + 'ON': tokens.Keyword, + 'OPTIMIZE': tokens.Keyword.DDL, + 'OR': tokens.Keyword, + 'ORDER': tokens.Keyword, + 'OUTER': tokens.Keyword, + 'OUTFILE': tokens.Keyword, + 'PART': tokens.Keyword, + 'PARTITION': tokens.Keyword, + 'POPULATE': tokens.Keyword, + 'PREWHERE': tokens.Keyword, + 'PRIMARY': tokens.Keyword, + 'RENAME': tokens.Keyword.DDL, + 'RESHARD': tokens.Keyword, + 'RIGHT': tokens.Keyword, + 'SELECT': tokens.Keyword.DML, + 'SET': tokens.Keyword.DDL, + 'SETTINGS': tokens.Keyword, + 'SHOW': tokens.Keyword.DDL, + 'TABLE': tokens.Keyword, + 'TABLES': tokens.Keyword, + 'TEMPORARY': tokens.Keyword, + 'THEN': tokens.Keyword, + 'TO': tokens.Keyword, + 'TOTALS': tokens.Keyword, + 'UNION': tokens.Keyword, + 'UNREPLICATED': tokens.Keyword, + 'USE': tokens.Keyword.DDL, + 'USING': tokens.Keyword, + 'VALUES': tokens.Keyword, + 'VIEW': tokens.Keyword, + 'WHEN': tokens.Keyword, + 'WHERE': tokens.Keyword, + 'WITH': tokens.Keyword, +}