Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 规则审计-策略新增/编辑-后端接口 --story=121513458 #508

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/backend/core/sql/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@
)


class Table(BaseModel):
"""
"""

table_name: str # 表名
alias: Optional[str] = None # 别名


class Field(BaseModel):
"""
字段
Expand Down Expand Up @@ -58,8 +67,8 @@ class JoinTable(BaseModel):

join_type: JoinType # 连接类型
link_fields: List[LinkField] # 连接字段
left_table: str # 左表
right_table: str # 右表
left_table: Table # 左表
right_table: Table # 右表


class Condition(BaseModel):
Expand Down Expand Up @@ -107,7 +116,7 @@ class SqlConfig(BaseModel):
"""

select_fields: List[Field] # 作为 sql 的列
from_table: str = "" # 主表
from_table: Optional[Table] = None # 主表
join_tables: Optional[List[JoinTable]] = None # 联表
where: Optional[WhereCondition] = None # 筛选条件
group_by: List[Field] = PydanticField(default_factory=list) # 分组条件;如果未指定但有聚合函数,则会自动添加 group by 条件
Expand Down
45 changes: 27 additions & 18 deletions src/backend/core/sql/sql_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
We undertake not to change the open source license (MIT license) applicable
to the current version of the project delivered to anyone in the future.
"""
from typing import Dict, Optional
from typing import Dict, Optional, Union

from pypika import Field as PypikaField
from pypika import Table
Expand All @@ -30,7 +30,9 @@
UnsupportedJoinTypeError,
UnsupportedOperatorError,
)
from core.sql.model import Condition, Field, SqlConfig, WhereCondition
from core.sql.model import Condition, Field, SqlConfig
from core.sql.model import Table as SqlTable
from core.sql.model import WhereCondition


class SQLGenerator:
Expand All @@ -48,22 +50,29 @@ def __init__(self, query_builder: QueryBuilder, config: SqlConfig):

def _register_tables(self):
"""注册所有有效的表名"""
valid_tables = set()
register_tables = {}

# 添加主表到注册表
if self.config.from_table:
valid_tables.add(self.config.from_table)
if self.config.join_tables:
for join_table in self.config.join_tables:
valid_tables.add(join_table.left_table)
valid_tables.add(join_table.right_table)
# 记录所有表名对应的 Table 对象
for table_name in valid_tables:
self.table_map[table_name] = Table(table_name).as_(table_name)

def _get_table(self, table_name: str) -> Table:
alias = self.config.from_table.alias or self.config.from_table.table_name
register_tables[alias] = self.config.from_table

# 添加连接表到注册表
for join_table in self.config.join_tables or []:
for table in [join_table.left_table, join_table.right_table]:
alias = table.alias or table.table_name
register_tables[alias] = table

# 更新 table_map 映射
self.table_map.update({alias: Table(table.table_name).as_(alias) for alias, table in register_tables.items()})

def _get_table(self, table: Union[str, SqlTable]) -> Table:
"""根据表名获取 Table 对象"""
if table_name not in self.table_map:
raise TableNotRegisteredError(table_name)
return self.table_map[table_name]
if isinstance(table, SqlTable):
table = table.alias or table.table_name
if table not in self.table_map:
raise TableNotRegisteredError(table)
return self.table_map[table]

def _get_pypika_field(self, field: Field) -> PypikaField:
"""根据 Field 获取 PyPika 字段"""
Expand All @@ -85,8 +94,8 @@ def _build_from(self, query: QueryBuilder) -> QueryBuilder:
"""添加 FROM 子句"""
if not (self.config.from_table or self.config.join_tables):
raise MissingFromOrJoinError()
if self.config.from_table:
query = query.from_(self._get_table(self.config.from_table))
from_table = self.config.join_tables[0].left_table if self.config.join_tables else self.config.from_table
query = query.from_(self._get_table(from_table))
if self.config.join_tables:
query = self._build_join(self.config.from_table, query)
return query
Expand Down
43 changes: 27 additions & 16 deletions src/backend/services/web/strategy_v2/handlers/rule_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
JoinTable,
LinkField,
SqlConfig,
Table,
WhereCondition,
)
from core.sql.sql_builder import SQLGenerator
Expand Down Expand Up @@ -130,13 +131,13 @@ def trans_field(self, field_json: dict) -> Field:
aggregate=field_json.get("aggregate"),
)

def build_system_ids_condition(self, table_rt_id: str, system_ids: list) -> WhereCondition:
def build_system_ids_condition(self, table_name: str, system_ids: list) -> WhereCondition:
"""
根据给定的表 rt_id 及 system_ids 列表,构建一个 AND 条件,用于拼接到最终的 WHERE 中。
"""
system_condition = Condition(
field=Field(
table=table_rt_id,
table=table_name,
raw_name=SYSTEM_ID.field_name,
display_name=SYSTEM_ID.alias_name,
field_type=SYSTEM_ID.field_type,
Expand All @@ -146,19 +147,20 @@ def build_system_ids_condition(self, table_rt_id: str, system_ids: list) -> Wher
)
return WhereCondition(connector=FilterConnector.AND, condition=system_condition)

def build_single_table_config(self, data_source: dict) -> (str, list, dict):
def build_single_table_config(self, data_source: dict) -> (Table, list, dict):
"""
处理单表场景,返回:
- from_table: str
- from_table: 主表
- join_tables: 空列表 (单表无 join)
- tables_with_system_ids: {rt_id: [system_ids]}
- tables_with_system_ids: {display_name: [system_ids]}
"""
from_table = data_source["rt_id"]
display_name = data_source.get("display_name", from_table)
system_ids = data_source.get("system_ids", [])
tables_with_system_ids = {from_table: system_ids}
return from_table, [], tables_with_system_ids
tables_with_system_ids = {display_name: system_ids}
return Table(table_name=from_table, alias=display_name), [], tables_with_system_ids

def build_link_table_config(self, data_source: dict) -> (str, list, dict):
def build_link_table_config(self, data_source: dict) -> (Table, list, dict):
"""
处理联表场景,从 link_table 配置中构建:
- from_table: 主表
Expand All @@ -177,19 +179,28 @@ def build_link_table_config(self, data_source: dict) -> (str, list, dict):

# 确定主表 (from_table)
first_link = links[0]
from_table = first_link["left_table"]["rt_id"]
from_table = first_link["left_table"]
_from_table = Table(table_name=from_table["rt_id"], alias=from_table.get("display_name", from_table["rt_id"]))
join_tables = []
tables_with_system_ids = {}

for lk in links:
left_table = lk["left_table"]
_left_table = Table(
table_name=left_table["rt_id"],
alias=left_table.get("display_name", left_table["rt_id"]),
)
right_table = lk["right_table"]
_right_table = Table(
table_name=right_table["rt_id"],
alias=right_table.get("display_name", right_table["rt_id"]),
)

# 如果 left_table 或 right_table 是 EVENT_LOG,则将它们的 system_ids 收集起来
if left_table["table_type"] == LinkTableTableType.EVENT_LOG:
tables_with_system_ids[left_table["rt_id"]] = left_table.get("system_ids", [])
tables_with_system_ids[_left_table.alias] = left_table.get("system_ids", [])
if right_table["table_type"] == LinkTableTableType.EVENT_LOG:
tables_with_system_ids[right_table["rt_id"]] = right_table.get("system_ids", [])
tables_with_system_ids[_right_table.alias] = right_table.get("system_ids", [])

# link_fields
link_fields_list = [
Expand All @@ -201,12 +212,12 @@ def build_link_table_config(self, data_source: dict) -> (str, list, dict):
JoinTable(
join_type=lk["join_type"],
link_fields=link_fields_list,
left_table=left_table["rt_id"],
right_table=right_table["rt_id"],
left_table=_left_table,
right_table=_right_table,
)
)

return from_table, join_tables, tables_with_system_ids
return _from_table, join_tables, tables_with_system_ids

def format(self, config_json: dict) -> SqlConfig:
"""
Expand All @@ -233,10 +244,10 @@ def format(self, config_json: dict) -> SqlConfig:
conditions_to_merge.append(front_where)

# Step D. 为每个包含 system_ids 的表构建条件
for rt_id, system_ids in tables_with_system_ids.items():
for table_name, system_ids in tables_with_system_ids.items():
if not system_ids: # 若没有 system_ids,可根据需要决定是否忽略或抛异常
continue
system_ids_where = self.build_system_ids_condition(rt_id, system_ids)
system_ids_where = self.build_system_ids_condition(table_name, system_ids)
conditions_to_merge.append(system_ids_where)

# Step E. 合并所有条件
Expand Down
12 changes: 12 additions & 0 deletions src/backend/services/web/strategy_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,7 @@ class LinkTableConfigTableSerializer(serializers.Serializer):

rt_id = serializers.CharField(label=gettext_lazy("Result Table ID"))
table_type = serializers.ChoiceField(label=gettext_lazy("Table Type"), choices=LinkTableTableType.choices)
display_name = serializers.CharField(label=gettext_lazy("Display Name"), required=False)
system_ids = serializers.ListField(
label=gettext_lazy("System IDs"), child=serializers.CharField(max_length=64), required=False
)
Expand All @@ -724,6 +725,9 @@ def validate(self, attrs):
attrs = super().validate(attrs)
if attrs["table_type"] == LinkTableTableType.EVENT_LOG and not attrs.get("system_ids"):
raise serializers.ValidationError(gettext("System IDs is required"))
# display_name 默认值为 rt_id
if not attrs.get("display_name"):
attrs["display_name"] = attrs["rt_id"]
return attrs


Expand Down Expand Up @@ -901,6 +905,14 @@ class RuleAuditDataSourceSerializer(serializers.Serializer):
label=gettext_lazy("System ID"), child=serializers.CharField(), required=False, allow_empty=True
)
link_table = RuleAuditLinkTableSerializer(label=gettext_lazy("Link Table"), required=False)
display_name = serializers.CharField(label=gettext_lazy("Display Name"), required=False)

def validate(self, attrs):
attrs = super().validate(attrs)
# display_name 默认值为 rt_id
if attrs.get("rt_id") and not attrs.get("display_name"):
attrs['display_name'] = attrs['rt_id']
return attrs


class RuleAuditConditionSerializer(serializers.Serializer):
Expand Down
Loading
Loading