diff --git a/internal/lineage/sqlparser.go b/internal/lineage/sqlparser.go
index 28d533e..aeb3ba7 100644
--- a/internal/lineage/sqlparser.go
+++ b/internal/lineage/sqlparser.go
@@ -96,11 +96,14 @@ func (o *Op) GetID() string {
 	return o.SchemaName + "." + o.ProcName
 }
 
-func ParseUDF(sqlTree *depgraph.Graph, plpgsql string) error {
+// ParseUDF parses the PL/pgSQL source in plpgsql and returns the dependency
+// graph built from it. The returned graph is nil whenever the error is non-nil.
+func ParseUDF(plpgsql string) (*depgraph.Graph, error) {
+	sqlTree := depgraph.New()
 
 	raw, err := pg_query.ParsePlPgSqlToJSON(plpgsql)
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	// log.Debugf("pg_query.ParsePlPgSqlToJSON: %s", raw)
@@ -123,7 +126,7 @@ func ParseUDF(sqlTree *depgraph.Graph, plpgsql string) error {
 		})
 	}
 
-	return nil
+	return sqlTree, nil
 }
 
 func parseUDFOperator(sqlTree *depgraph.Graph, operator, plan string) error {
@@ -153,6 +156,18 @@ func parseUDFOperator(sqlTree *depgraph.Graph, operator, plan string) error {
 	return nil
 }
 
+// Parse parses sql with ParseSQL into a newly created dependency graph and
+// returns that graph, or nil and the error if ParseSQL fails.
+func Parse(sql string) (*depgraph.Graph, error) {
+	sqlTree := depgraph.New()
+
+	if err := ParseSQL(sqlTree, sql); err != nil {
+		return nil, err
+	}
+
+	return sqlTree, nil
+}
+
 func ParseSQL(sqlTree *depgraph.Graph, sql string) error {
 
 	raw, err := pg_query.ParseToJSON(sql)
diff --git a/main.go b/main.go
index 7af5d90..1f74ff9 100644
--- a/main.go
+++ b/main.go
@@ -160,8 +160,6 @@ func generateTableJoinRelation(qs *QueryStore, ds *DataSource, driver neo4j.Driv
 
 func generateTableLineage(qs *QueryStore, ds *DataSource, driver neo4j.Driver) {
 	// 一个 UDF 一张图
-	sqlTree := depgraph.New(ds.Alias)
-
 	udf, err := IdentifyFuncCall(qs.Query)
 	if err != nil {
 		return
@@ -171,7 +169,8 @@ func generateTableLineage(qs *QueryStore, ds *DataSource, driver neo4j.Driver) {
 	// ProcName: "func_insert_fact_sn_info_f6",
 	// SchemaName: "dw",
 	// }
-	if err := HandleUDF4Lineage(sqlTree, ds.DB, udf); err != nil {
+	sqlTree, err := HandleUDF4Lineage(ds.DB, udf)
+	if err != nil {
 		log.Errorf("HandleUDF %+v, err: %s", udf, err)
 		return
 	}
@@ -181,7 +180,9 @@ func generateTableLineage(qs *QueryStore, ds *DataSource, driver neo4j.Driver) {
 		log.Debugf("UDF Graph %d: %s\n", i, strings.Join(layer, ", "))
 	}
 
-	// TODO:完善辅助信息
+	// Set the owning namespace to avoid node collisions.
+	sqlTree.SetNamespace(ds.Alias)
+	// Fill in auxiliary information.
 
 	if err := lineage.CreateGraph(driver, sqlTree.ShrinkGraph(), udf); err != nil {
 		log.Errorf("UDF CreateGraph err: %s ", err)
@@ -216,18 +217,18 @@ func IdentifyFuncCall(sql string) (*lineage.Op, error) {
 }
 
 // 解析函数调用
-func HandleUDF4Lineage(sqlTree *depgraph.Graph, db *sql.DB, udf *lineage.Op) error {
+func HandleUDF4Lineage(db *sql.DB, udf *lineage.Op) (*depgraph.Graph, error) {
 	log.Infof("HandleUDF: %s.%s", udf.SchemaName, udf.ProcName)
 
 	// 排除系统函数的干扰 e.g. select now()
 	if udf.SchemaName == "" || udf.SchemaName == "pg_catalog" {
-		return fmt.Errorf("UDF %s is system function", udf.ProcName)
+		return nil, fmt.Errorf("UDF %s is system function", udf.ProcName)
 	}
 
 	definition, err := GetUDFDefinition(db, udf)
 	if err != nil {
 		log.Errorf("GetUDFDefinition err: %s", err)
-		return err
+		return nil, err
 	}
 
 	// 字符串过滤,后期 pg_query 支持 set 了,可以去掉
@@ -235,12 +236,13 @@ func HandleUDF4Lineage(sqlTree *depgraph.Graph, db *sql.DB, udf *lineage.Op) err
 	plpgsql := filterUnhandledCommands(definition)
 	// log.Debug("plpgsql: ", plpgsql)
 
-	if err := lineage.ParseUDF(sqlTree, plpgsql); err != nil {
+	sqlTree, err := lineage.ParseUDF(plpgsql)
+	if err != nil {
 		log.Errorf("ParseUDF %+v, err: %s", udf, err)
-		return err
+		return nil, err
 	}
 
-	return nil
+	return sqlTree, nil
 }
 
 func HandleUDF4ERD(db *sql.DB, udf *lineage.Op) (map[string]*erd.RelationShip, error) {
diff --git a/pkg/depgraph/depgraph.go b/pkg/depgraph/depgraph.go
index 496d5bb..16bab42 100644
--- a/pkg/depgraph/depgraph.go
+++ b/pkg/depgraph/depgraph.go
@@ -34,12 +34,14 @@ type Graph struct {
 	namespace string
 }
 
-func New(namespace string) *Graph {
+// New returns an empty Graph whose namespace is "default"; use SetNamespace
+// to assign a different one.
+func New() *Graph {
 	return &Graph{
 		dependencies: make(depmap),
 		dependents:   make(depmap),
 		nodes:        make(nodeset),
-		namespace:    namespace,
+		namespace:    "default",
 	}
 }
 
@@ -55,6 +57,11 @@ func (g *Graph) GetNamespace() string {
 	return g.namespace
 }
 
+// SetNamespace replaces the graph's namespace with namespace.
+func (g *Graph) SetNamespace(namespace string) {
+	g.namespace = namespace
+}
+
 // Add nodes and relationships
 func (g *Graph) DependOn(child Node, parent Node) error {
 	if child.GetID() == parent.GetID() {