Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(qd): 🦄 优化代码和 logger 格式和静态类型 #496

Merged
merged 12 commits into from
Feb 3, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[flake8]
max-line-length = 120
ignore = E203, E266, E501, W503
exclude = .git, __pycache__, venv, dist, build
5 changes: 5 additions & 0 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ tzdata = "*"
[dev-packages]
#ddddocr = "*"
#pycurl = {version = "*", markers="sys_platform != 'win32'"}
types-croniter = "*"
types-requests = "*"
types-python-dateutil = "*"
sqlalchemy2-stubs = "*"
types-aiofiles = "*"

[requires]
python_version = "3.11"
1,078 changes: 612 additions & 466 deletions Pipfile.lock

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Author: Binux<[email protected]>
# http://binux.me
# Created on 2014-08-09 11:39:25
# pylint: disable=broad-exception-raised

import sqlite3

Expand Down Expand Up @@ -127,30 +128,29 @@ def new(self, userid, maindb):
`notepadid` INTEGER NOT NULL ,
`content` TEXT NULL
);
''' )
''')

# 获取数据库信息
userid = int(userid)
user = maindb.db.user.get(id=userid, fields=('id', 'email', 'email_verified', 'password', 'password_md5', 'userkey', 'nickname', 'role', 'ctime', 'mtime', 'atime', 'cip',
'mip', 'aip', 'skey', 'barkurl', 'wxpusher', 'noticeflg', 'logtime', 'status', 'notepad', 'diypusher', 'qywx_token', 'tg_token', 'dingding_token', 'qywx_webhook', 'push_batch'))
userkey = maindb.db.user.__getuserkey(user['env'])
# user = maindb.db.user.get(id=userid, fields=('id', 'email', 'email_verified', 'password', 'password_md5', 'userkey', 'nickname', 'role', 'ctime', 'mtime', 'atime', 'cip',
# 'mip', 'aip', 'skey', 'barkurl', 'wxpusher', 'noticeflg', 'logtime', 'status', 'notepad', 'diypusher', 'qywx_token', 'tg_token', 'dingding_token', 'qywx_webhook', 'push_batch'))
# userkey = maindb.db.user.__getuserkey(user['env'])
tpls = []
for tpl in maindb.db.tpl.list(fields=('id', 'userid', 'siteurl', 'sitename', 'banner', 'disabled', 'public', 'lock', 'fork', 'har', 'tpl', 'variables', 'interval', 'note', 'success_count', 'failed_count', 'last_success', 'ctime', 'mtime', 'atime', 'tplurl', 'updateable', '_groups', 'init_env'), limit=None):
if tpl['userid'] == userid:
tpls.append(tpl)
tasks = []
tasklogs = []
for task in maindb.db.task.list(userid, fields=('id', 'tplid', 'userid', 'note', 'disabled', 'init_env', 'env', 'session', 'retry_count', 'retry_interval', 'last_success', 'success_count',
'failed_count', 'last_failed', 'next', 'last_failed_count', 'ctime', 'mtime', 'ontimeflg', 'ontime', '_groups', 'pushsw', 'newontime'), limit=None):
'failed_count', 'last_failed', 'next', 'last_failed_count', 'ctime', 'mtime', 'ontimeflg', 'ontime', '_groups', 'pushsw', 'newontime'), limit=None):
if task['userid'] == userid:
tasks.append(task)
for tasklog in maindb.db.tasklog.list(taskid = task['id'], fields=('id', "taskid", "success", "ctime", "msg")):
for tasklog in maindb.db.tasklog.list(taskid=task['id'], fields=('id', "taskid", "success", "ctime", "msg")):
tasklogs.append(tasklog)

c.close()
conn.close()


except Exception as e:
raise Exception("backup database error")
raise Exception("backup database error") from e
print("OK")
12 changes: 8 additions & 4 deletions chrole.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,20 @@
"""

import asyncio
import logging
import sys

import db

logger = logging.getLogger(__name__)


def usage():
print('Usage: python3 %s <email> [role]' % sys.argv[0])
print('Example: python3 %s [email protected] admin' % sys.argv[0])
print(f'Usage: python3 {sys.argv[0]} <email> [role]')
print(f'Example: python3 {sys.argv[0]} [email protected] admin')
sys.exit(1)


async def main():
email = sys.argv[1]
role = sys.argv[2] if len(sys.argv) == 3 else ''
Expand All @@ -31,9 +35,9 @@ async def main():
sys.exit(1)
rowcount = await userdb.mod(user['id'], role=role)
if rowcount >= 1:
print("role of {} changed to {}".format(email, role or '[empty]'))
logger.info("role of %s changed to %s", email, role or '[empty]')
else:
print("role of {} not changed".format(email))
logger.warning("role of %s not changed", email)


if __name__ == '__main__':
Expand Down
141 changes: 79 additions & 62 deletions config.py

Large diffs are not rendered by default.

23 changes: 9 additions & 14 deletions db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,17 @@
# http://binux.me
# Created on 2014-08-08 20:28:15

import os
import sys

from db.basedb import AlchemyMixin

sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
from .notepad import Notepad
from .pubtpl import Pubtpl
from .push_request import PushRequest
from .redisdb import RedisDB
from .site import Site
from .task import Task
from .tasklog import Tasklog
from .tpl import Tpl
from .user import User
from db.notepad import Notepad
from db.pubtpl import Pubtpl
from db.push_request import PushRequest
from db.redisdb import RedisDB
from db.site import Site
from db.task import Task
from db.tasklog import Tasklog
from db.tpl import Tpl
from db.user import User


class DB(AlchemyMixin):
Expand All @@ -33,4 +29,3 @@ def __init__(self) -> None:
self.site = Site()
self.pubtpl = Pubtpl()
self.notepad = Notepad()

108 changes: 53 additions & 55 deletions db/basedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,75 +6,76 @@
# Created on 2012-08-30 17:43:49

import contextlib
import logging
from asyncio import current_task
from typing import Tuple
from typing import AsyncIterator, Optional, Union

from sqlalchemy import text
from sqlalchemy.dialects.mysql import Insert
from sqlalchemy.engine import CursorResult, Result, ScalarResult
from sqlalchemy.engine import Result, Row
from sqlalchemy.ext.asyncio import (AsyncSession, async_scoped_session,
create_async_engine)
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.sql import Delete, Select, Update
from sqlalchemy.sql.elements import TextClause

import config
from libs.log import Log

if config.db_type == 'mysql':
host=config.mysql.host
port=config.mysql.port
database=config.mysql.database
user=config.mysql.user
passwd=config.mysql.passwd
auth_plugin=config.mysql.auth_plugin
host = config.mysql.host
port = config.mysql.port
database = config.mysql.database
user = config.mysql.user
passwd = config.mysql.passwd
auth_plugin = config.mysql.auth_plugin
engine_url = f"mysql+aiomysql://{user}:{passwd}@{host}:{port}/{database}?auth_plugin={auth_plugin}"
engine = create_async_engine(engine_url,
logging_name = config.sqlalchemy.logging_name,
pool_size = config.sqlalchemy.pool_size,
max_overflow = config.sqlalchemy.max_overflow,
pool_logging_name = config.sqlalchemy.pool_logging_name,
pool_pre_ping = config.sqlalchemy.pool_pre_ping,
pool_recycle = config.sqlalchemy.pool_recycle,
pool_timeout = config.sqlalchemy.pool_timeout,
pool_use_lifo = config.sqlalchemy.pool_use_lifo)
logging_name=config.sqlalchemy.logging_name,
pool_size=config.sqlalchemy.pool_size,
max_overflow=config.sqlalchemy.max_overflow,
pool_logging_name=config.sqlalchemy.pool_logging_name,
pool_pre_ping=config.sqlalchemy.pool_pre_ping,
pool_recycle=config.sqlalchemy.pool_recycle,
pool_timeout=config.sqlalchemy.pool_timeout,
pool_use_lifo=config.sqlalchemy.pool_use_lifo)
elif config.db_type == 'sqlite3':
engine_url = f"sqlite+aiosqlite:///{config.sqlite3.path}"
engine = create_async_engine(engine_url,
logging_name = config.sqlalchemy.logging_name,
pool_logging_name = config.sqlalchemy.pool_logging_name,
pool_pre_ping = config.sqlalchemy.pool_pre_ping,
pool_recycle = config.sqlalchemy.pool_recycle )
logging_name=config.sqlalchemy.logging_name,
pool_logging_name=config.sqlalchemy.pool_logging_name,
pool_pre_ping=config.sqlalchemy.pool_pre_ping,
pool_recycle=config.sqlalchemy.pool_recycle)
Log('aiosqlite',
logger_level=config.sqlalchemy.pool_logging_level,
channel_level=config.sqlalchemy.pool_logging_level).getlogger()
else:
raise Exception('db_type must be mysql or sqlite3')
logger_DB = Log('sqlalchemy',
logger_db = Log('sqlalchemy',
logger_level=config.sqlalchemy.logging_level,
channel_level=config.sqlalchemy.logging_level).getlogger()
logger_DB_Engine = Log(engine.engine.logger,
logger_level=config.sqlalchemy.logging_level,
channel_level=config.sqlalchemy.logging_level).getlogger()
if hasattr(engine.pool.logger, 'logger'):
logger_DB_POOL = Log(engine.pool.logger.logger,
logger_level=config.sqlalchemy.pool_logging_level,
channel_level=config.sqlalchemy.pool_logging_level).getlogger()
else:
logger_DB_POOL = Log(engine.pool.logger,
logger_level=config.sqlalchemy.pool_logging_level,
channel_level=config.sqlalchemy.pool_logging_level).getlogger()
logger_db_engine = Log(getattr(engine.sync_engine, 'logger', f'sqlalchemy.engine.Engine.{config.sqlalchemy.logging_name}'),
logger_level=config.sqlalchemy.logging_level,
channel_level=config.sqlalchemy.logging_level).getlogger()
if hasattr(engine.sync_engine.pool, 'logger'):
if hasattr(getattr(engine.sync_engine.pool, 'logger'), 'logger'):
logger_db_pool = Log(engine.sync_engine.pool.logger.logger,
logger_level=config.sqlalchemy.pool_logging_level,
channel_level=config.sqlalchemy.pool_logging_level).getlogger()
else:
logger_db_pool = Log(engine.sync_engine.pool.logger,
logger_level=config.sqlalchemy.pool_logging_level,
channel_level=config.sqlalchemy.pool_logging_level).getlogger()
async_session = async_scoped_session(sessionmaker(engine, class_=AsyncSession, expire_on_commit=False),
scopefunc=current_task)
BaseDB = declarative_base(bind=engine, name="BaseDB")


class AlchemyMixin:
@property
def sql_session(self) -> AsyncSession:
return async_session()

@contextlib.asynccontextmanager
async def transaction(self, sql_session:AsyncSession=None):
async def transaction(self, sql_session: Optional[AsyncSession] = None) -> AsyncIterator[AsyncSession]:
if sql_session is None:
async with self.sql_session as sql_session:
# deepcode ignore AttributeLoadOnNone: sql_session is not None
Expand All @@ -86,26 +87,24 @@ async def transaction(self, sql_session:AsyncSession=None):
else:
yield sql_session

async def _execute(self, text:Tuple[str,text], sql_session:AsyncSession=None):
async def _execute(self, text: Union[str, TextClause], sql_session: Optional[AsyncSession] = None):
async with self.transaction(sql_session) as sql_session:
if isinstance(text, str):
text = text.replace(':', r'\:')
text = text.replace(':', r'\:') # 如果text原本是个字符串,则转义冒号
text = TextClause(text) # 将其转换为TextClause对象
result = await sql_session.execute(text)
return result

async def _get(self, stmt: Select, one_or_none=False, first=False, all=True, sql_session:AsyncSession=None):
async def _get(self, stmt: Select, one_or_none=False, first=False, sql_session: Optional[AsyncSession] = None):
async with self.transaction(sql_session) as sql_session:
result: Result = await sql_session.execute(stmt)
if one_or_none:
return result.scalar_one_or_none()
elif first:
if first:
return result.first()
elif all:
return result.all()
else:
return result
return result.all()

async def _insert(self, instance, many=False, sql_session:AsyncSession=None):
async def _insert(self, instance, many=False, sql_session: Optional[AsyncSession] = None):
async with self.transaction(sql_session) as sql_session:
if many:
sql_session.add_all(instance)
Expand All @@ -114,27 +113,26 @@ async def _insert(self, instance, many=False, sql_session:AsyncSession=None):
await sql_session.flush()
return instance.id

async def _update(self, stmt: Update, sql_session:AsyncSession=None):
async def _update(self, stmt: Update, sql_session: Optional[AsyncSession] = None):
async with self.transaction(sql_session) as sql_session:
result: Result = await sql_session.execute(stmt)
return result.rowcount
return result.rowcount if hasattr(result, 'rowcount') else -1

async def _insert_or_update(self, insert_stmt: Insert, sql_session:AsyncSession=None, **kwargs) -> int:
async def _insert_or_update(self, insert_stmt: Insert, sql_session: Optional[AsyncSession] = None, **kwargs) -> int:
async with self.transaction(sql_session) as sql_session:
on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(**kwargs)
result: CursorResult = await sql_session.execute(on_duplicate_key_stmt)
return result.lastrowid
insert_stmt.on_duplicate_key_update(**kwargs)
result: Result = await sql_session.execute(insert_stmt)
return result.lastrowid if hasattr(result, 'lastrowid') else -1

async def _delete(self, stmt: Delete, sql_session:AsyncSession=None):
async def _delete(self, stmt: Delete, sql_session: Optional[AsyncSession] = None):
async with self.transaction(sql_session) as sql_session:
result: Result = await sql_session.execute(stmt)
return result.rowcount
return result.rowcount if hasattr(result, 'rowcount') else -1

@staticmethod
def to_dict(result,fields=None):
def to_dict(result: Row, fields=None):
if result is None:
return result
if fields is None:
return {c.name: getattr(result[0], c.name) for c in result[0].__table__.columns}
else:
return dict(result._mapping)
return dict(result._mapping) # pylint: disable=protected-access
Loading
Loading