Skip to content

Commit

Permalink
chore+fix: further refactor and cleanup; BED-5532 batch operation beh…
Browse files Browse the repository at this point in the history
…avior change to move away from transaction thrashing in pg
  • Loading branch information
zinic committed Mar 5, 2025
1 parent d51591f commit f578d5c
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 95 deletions.
21 changes: 6 additions & 15 deletions packages/go/cypher/models/pgsql/translate/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,11 @@ func expansionColumns() pgsql.RecordShape {
}
}

type Match struct {
Pattern *Pattern
}

type NodeSelect struct {
Frame *Frame
Binding *BoundIdentifier
Select pgsql.Select
Constraint *Constraint
Frame *Frame
Binding *BoundIdentifier
Select pgsql.Select
Constraints pgsql.Expression
}

type Expansion struct {
Expand All @@ -66,19 +62,15 @@ type Expansion struct {
MinDepth models.Optional[int64]
MaxDepth models.Optional[int64]

PrimerProjection []pgsql.SelectItem
PrimerConstraints pgsql.Expression

RecursiveProjection []pgsql.SelectItem
PrimerConstraints pgsql.Expression
RecursiveConstraints pgsql.Expression

LeftNodeJoinCondition pgsql.Expression
ExpansionEdgeConstraints pgsql.Expression
ExpansionNodeConstraints pgsql.Expression
TerminalNodeConstraints pgsql.Expression

Projection []pgsql.SelectItem
Constraints pgsql.Expression
Projection []pgsql.SelectItem
}

type PatternSegment struct {
Expand All @@ -96,7 +88,6 @@ type PatternSegment struct {
RightNodeBound bool
RightNodeConstraints pgsql.Expression
RightNodeJoinCondition pgsql.Expression
Definitions []*BoundIdentifier
Projection []pgsql.SelectItem
}

Expand Down
46 changes: 7 additions & 39 deletions packages/go/cypher/models/pgsql/translate/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,37 +45,6 @@ func (s *Translator) translateNodePatternSegment(nodePattern *cypher.NodePattern
return nil
}

func (s *Translator) translateNodePatternSegmentWithTraversal(currentSegment *PatternSegment) error {
// Note: The order below matters as it will change the order of projections in the resulting translation

// If either of the nodes are not bound at this point then this traversal must materialize them
if !currentSegment.LeftNodeBound {
currentSegment.Definitions = append(currentSegment.Definitions, currentSegment.LeftNode)
}

// Add the edge symbol as part of the definitions that are being materialized by this expansion
currentSegment.Definitions = append(currentSegment.Definitions, currentSegment.Edge)

// If there's an expansion attached to this traversal, ensure that the symbol for the expansion's scope frame
// is part of the definitions that are being materialized by this expansion
if currentSegment.Expansion.Set {
currentSegment.Definitions = append(currentSegment.Definitions, currentSegment.Expansion.Value.PathBinding)
}

// If the right node has not been bound before add it to the pattern part's list of new definitions
if !currentSegment.RightNodeBound {
currentSegment.Definitions = append(currentSegment.Definitions, currentSegment.RightNode)
}

if currentSegment.Expansion.Set {
// Update the data type of the right node so that it reflects that it is now the terminal node of
// an expansion
currentSegment.RightNode.DataType = pgsql.ExpansionTerminalNode
}

return nil
}

func (s *Translator) translateNodePatternToStep(nodePattern *cypher.NodePattern, part *PatternPart, bindingResult BindingResult) error {
currentQueryPart := s.query.CurrentPart()

Expand Down Expand Up @@ -129,7 +98,9 @@ func (s *Translator) translateNodePatternToStep(nodePattern *cypher.NodePattern,
currentStep.RightNodeBound = bindingResult.AlreadyBound

// Finish setting up this traversal step
return s.translateNodePatternSegmentWithTraversal(currentStep)
if currentStep.Expansion.Set {
currentStep.RightNode.DataType = pgsql.ExpansionTerminalNode
}
}
} else {
return s.translateNodePatternSegment(nodePattern, part, bindingResult)
Expand All @@ -141,7 +112,10 @@ func (s *Translator) translateNodePatternToStep(nodePattern *cypher.NodePattern,
func (s *Translator) buildNodePattern(part *PatternPart) error {
var (
partFrame = part.NodeSelect.Frame
nextSelect pgsql.Select
nextSelect = pgsql.Select{
Projection: part.NodeSelect.Select.Projection,
Where: part.NodeSelect.Constraints,
}
)

// The current query part may not have a frame associated with it if is a single part query component
Expand All @@ -153,12 +127,6 @@ func (s *Translator) buildNodePattern(part *PatternPart) error {
})
}

nextSelect.Projection = part.NodeSelect.Select.Projection

if part.NodeSelect.Constraint != nil {
nextSelect.Where = part.NodeSelect.Constraint.Expression
}

nextSelect.From = append(nextSelect.From, pgsql.FromClause{
Source: pgsql.TableReference{
Name: pgsql.CompoundIdentifier{pgsql.TableNode},
Expand Down
2 changes: 1 addition & 1 deletion packages/go/cypher/models/pgsql/translate/traversal.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ func (s *Translator) translateNonTraversalPatternPart(part *PatternPart) error {
} else if err := RewriteFrameBindings(s.scope, constraint.Expression); err != nil {
return err
} else {
part.NodeSelect.Constraint = constraint
part.NodeSelect.Constraints = constraint.Expression
}

if boundProjections, err := buildVisibleProjections(s.scope); err != nil {
Expand Down
18 changes: 9 additions & 9 deletions packages/go/dawgs/drivers/pg/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (s *Int2ArrayEncoder) Encode(values []int16) string {

type batch struct {
ctx context.Context
innerTransaction *transaction
innerTransaction *txWrapper
schemaManager *SchemaManager
nodeDeletionBuffer []graph.ID
relationshipDeletionBuffer []graph.ID
Expand All @@ -67,7 +67,7 @@ type batch struct {
}

func newBatch(ctx context.Context, conn *pgxpool.Conn, schemaManager *SchemaManager, cfg *Config) (*batch, error) {
if tx, err := newTransaction(ctx, conn, schemaManager, cfg); err != nil {
if tx, err := newTXWrapper(ctx, conn, schemaManager, cfg, false); err != nil {
return nil, err
} else {
return &batch{
Expand Down Expand Up @@ -106,7 +106,7 @@ func (s *batch) UpdateNodeBy(update graph.NodeUpdate) error {
}

func (s *batch) flushNodeDeleteBuffer() error {
if _, err := s.innerTransaction.tx.Exec(s.ctx, deleteNodeWithIDStatement, s.nodeDeletionBuffer); err != nil {
if _, err := s.innerTransaction.conn.Exec(s.ctx, deleteNodeWithIDStatement, s.nodeDeletionBuffer); err != nil {
return err
}

Expand All @@ -115,7 +115,7 @@ func (s *batch) flushNodeDeleteBuffer() error {
}

func (s *batch) flushRelationshipDeleteBuffer() error {
if _, err := s.innerTransaction.tx.Exec(s.ctx, deleteEdgeWithIDStatement, s.relationshipDeletionBuffer); err != nil {
if _, err := s.innerTransaction.conn.Exec(s.ctx, deleteEdgeWithIDStatement, s.relationshipDeletionBuffer); err != nil {
return err
}

Expand Down Expand Up @@ -177,7 +177,7 @@ func (s *batch) flushNodeCreateBufferWithIDs() error {

if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil {
return err
} else if _, err := s.innerTransaction.tx.Exec(s.ctx, createNodeWithIDBatchStatement, graphTarget.ID, nodeIDs, kindIDSlices, properties); err != nil {
} else if _, err := s.innerTransaction.conn.Exec(s.ctx, createNodeWithIDBatchStatement, graphTarget.ID, nodeIDs, kindIDSlices, properties); err != nil {
return err
}

Expand Down Expand Up @@ -211,7 +211,7 @@ func (s *batch) flushNodeCreateBufferWithoutIDs() error {

if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil {
return err
} else if _, err := s.innerTransaction.tx.Exec(s.ctx, createNodeWithoutIDBatchStatement, graphTarget.ID, kindIDSlices, properties); err != nil {
} else if _, err := s.innerTransaction.conn.Exec(s.ctx, createNodeWithoutIDBatchStatement, graphTarget.ID, kindIDSlices, properties); err != nil {
return err
}

Expand All @@ -231,7 +231,7 @@ func (s *batch) flushNodeUpsertBatch(updates *sql.NodeUpdateBatch) error {
} else {
query := sql.FormatNodeUpsert(graphTarget, updates.IdentityProperties)

if rows, err := s.innerTransaction.tx.Query(s.ctx, query, parameters.Format(graphTarget)...); err != nil {
if rows, err := s.innerTransaction.conn.Query(s.ctx, query, parameters.Format(graphTarget)...); err != nil {
return err
} else {
defer rows.Close()
Expand Down Expand Up @@ -382,7 +382,7 @@ func (s *batch) flushRelationshipUpdateByBuffer(updates *sql.RelationshipUpdateB
} else {
query := sql.FormatRelationshipPartitionUpsert(graphTarget, updates.IdentityProperties)

if _, err := s.innerTransaction.tx.Exec(s.ctx, query, parameters.Format(graphTarget)...); err != nil {
if _, err := s.innerTransaction.conn.Exec(s.ctx, query, parameters.Format(graphTarget)...); err != nil {
return err
}
}
Expand Down Expand Up @@ -501,7 +501,7 @@ func (s *batch) flushRelationshipCreateBuffer() error {
return err
} else if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil {
return err
} else if _, err := s.innerTransaction.tx.Exec(s.ctx, createEdgeBatchStatement, graphTarget.ID, createBatch.startIDs, createBatch.endIDs, createBatch.edgeKindIDs, createBatch.edgePropertyBags); err != nil {
} else if _, err := s.innerTransaction.conn.Exec(s.ctx, createEdgeBatchStatement, graphTarget.ID, createBatch.startIDs, createBatch.endIDs, createBatch.edgeKindIDs, createBatch.edgePropertyBags); err != nil {
slog.Info(fmt.Sprintf("Num merged property bags: %d - Num edge keys: %d - StartID batch size: %d", len(batchBuilder.edgePropertiesIndex), len(batchBuilder.keyToEdgeID), len(batchBuilder.relationshipUpdateBatch.startIDs)))
return err
}
Expand Down
4 changes: 2 additions & 2 deletions packages/go/dawgs/drivers/pg/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (s *Driver) ReadTransaction(ctx context.Context, txDelegate graph.Transacti
} else {
defer conn.Release()

return txDelegate(&transaction{
return txDelegate(&txWrapper{
schemaManager: s.schemaManager,
queryExecMode: cfg.QueryExecMode,
ctx: ctx,
Expand All @@ -152,7 +152,7 @@ func (s *Driver) WriteTransaction(ctx context.Context, txDelegate graph.Transact
} else {
defer conn.Release()

if tx, err := newTransaction(ctx, conn, s.schemaManager, cfg); err != nil {
if tx, err := newTXWrapper(ctx, conn, s.schemaManager, cfg, true); err != nil {
return err
} else {
defer tx.Close()
Expand Down
63 changes: 34 additions & 29 deletions packages/go/dawgs/drivers/pg/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (s inspectingDriver) QueryRow(ctx context.Context, sql string, arguments ..
return s.upstreamDriver.QueryRow(ctx, sql, arguments...)
}

type transaction struct {
type txWrapper struct {
schemaManager *SchemaManager
queryExecMode pgx.QueryExecMode
queryResultsFormat pgx.QueryResultFormats
Expand All @@ -69,23 +69,28 @@ type transaction struct {
targetSchemaSet bool
}

func newTransaction(ctx context.Context, conn *pgxpool.Conn, schemaManager *SchemaManager, cfg *Config) (*transaction, error) {
if pgxTx, err := conn.BeginTx(ctx, cfg.Options); err != nil {
return nil, err
} else {
return &transaction{
schemaManager: schemaManager,
queryExecMode: cfg.QueryExecMode,
queryResultsFormat: cfg.QueryResultFormats,
ctx: ctx,
conn: conn,
tx: pgxTx,
targetSchemaSet: false,
}, nil
func newTXWrapper(ctx context.Context, conn *pgxpool.Conn, schemaManager *SchemaManager, cfg *Config, allocateTransaction bool) (*txWrapper, error) {
wrapper := &txWrapper{
schemaManager: schemaManager,
queryExecMode: cfg.QueryExecMode,
queryResultsFormat: cfg.QueryResultFormats,
ctx: ctx,
conn: conn,
targetSchemaSet: false,
}

if allocateTransaction {
if pgxTx, err := conn.BeginTx(ctx, cfg.Options); err != nil {
return nil, err
} else {
wrapper.tx = pgxTx
}
}

return wrapper, nil
}

func (s *transaction) driver() driver {
func (s *txWrapper) driver() driver {
if s.tx != nil {
return inspectingDriver{
upstreamDriver: s.tx,
Expand All @@ -97,25 +102,25 @@ func (s *transaction) driver() driver {
}
}

func (s *transaction) GraphQueryMemoryLimit() size.Size {
func (s *txWrapper) GraphQueryMemoryLimit() size.Size {
return size.Gibibyte
}

func (s *transaction) WithGraph(schema graph.Graph) graph.Transaction {
func (s *txWrapper) WithGraph(schema graph.Graph) graph.Transaction {
s.targetSchema = schema
s.targetSchemaSet = true

return s
}

func (s *transaction) Close() {
func (s *txWrapper) Close() {
if s.tx != nil {
s.tx.Rollback(s.ctx)
s.tx = nil
}
}

func (s *transaction) getTargetGraph() (model.Graph, error) {
func (s *txWrapper) getTargetGraph() (model.Graph, error) {
if !s.targetSchemaSet {
// Look for a default graph target
if defaultGraph, hasDefaultGraph := s.schemaManager.DefaultGraph(); !hasDefaultGraph {
Expand All @@ -128,7 +133,7 @@ func (s *transaction) getTargetGraph() (model.Graph, error) {
return s.schemaManager.AssertGraph(s, s.targetSchema)
}

func (s *transaction) CreateNode(properties *graph.Properties, kinds ...graph.Kind) (*graph.Node, error) {
func (s *txWrapper) CreateNode(properties *graph.Properties, kinds ...graph.Kind) (*graph.Node, error) {
if graphTarget, err := s.getTargetGraph(); err != nil {
return nil, err
} else if kindIDSlice, err := s.schemaManager.AssertKinds(s.ctx, kinds); err != nil {
Expand All @@ -155,7 +160,7 @@ func (s *transaction) CreateNode(properties *graph.Properties, kinds ...graph.Ki
}
}

func (s *transaction) UpdateNode(node *graph.Node) error {
func (s *txWrapper) UpdateNode(node *graph.Node) error {
var (
properties = node.Properties
updateStatements []graph.Criteria
Expand Down Expand Up @@ -183,13 +188,13 @@ func (s *transaction) UpdateNode(node *graph.Node) error {
}, updateStatements...)
}

func (s *transaction) Nodes() graph.NodeQuery {
func (s *txWrapper) Nodes() graph.NodeQuery {
return &nodeQuery{
liveQuery: newLiveQuery(s.ctx, s, s.schemaManager),
}
}

func (s *transaction) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) (*graph.Relationship, error) {
func (s *txWrapper) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) (*graph.Relationship, error) {
if graphTarget, err := s.getTargetGraph(); err != nil {
return nil, err
} else if kindIDSlice, err := s.schemaManager.AssertKinds(s.ctx, graph.Kinds{kind}); err != nil {
Expand Down Expand Up @@ -218,7 +223,7 @@ func (s *transaction) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, k
}
}

func (s *transaction) UpdateRelationship(relationship *graph.Relationship) error {
func (s *txWrapper) UpdateRelationship(relationship *graph.Relationship) error {
var (
modifiedProperties = relationship.Properties.ModifiedProperties()
deletedProperties = relationship.Properties.DeletedProperties()
Expand Down Expand Up @@ -261,13 +266,13 @@ func (s *transaction) UpdateRelationship(relationship *graph.Relationship) error
return err
}

func (s *transaction) Relationships() graph.RelationshipQuery {
func (s *txWrapper) Relationships() graph.RelationshipQuery {
return &relationshipQuery{
liveQuery: newLiveQuery(s.ctx, s, s.schemaManager),
}
}

func (s *transaction) query(query string, parameters map[string]any) (pgx.Rows, error) {
func (s *txWrapper) query(query string, parameters map[string]any) (pgx.Rows, error) {
queryArgs := []any{s.queryExecMode, s.queryResultsFormat}

if len(parameters) > 0 {
Expand All @@ -277,7 +282,7 @@ func (s *transaction) query(query string, parameters map[string]any) (pgx.Rows,
return s.driver().Query(s.ctx, query, queryArgs...)
}

func (s *transaction) Query(query string, parameters map[string]any) graph.Result {
func (s *txWrapper) Query(query string, parameters map[string]any) graph.Result {
if parsedQuery, err := frontend.ParseCypher(frontend.NewContext(), query); err != nil {
return graph.NewErrorResult(err)
} else if translated, err := translate.Translate(s.ctx, parsedQuery, s.schemaManager, parameters); err != nil {
Expand All @@ -289,7 +294,7 @@ func (s *transaction) Query(query string, parameters map[string]any) graph.Resul
}
}

func (s *transaction) Raw(query string, parameters map[string]any) graph.Result {
func (s *txWrapper) Raw(query string, parameters map[string]any) graph.Result {
if rows, err := s.query(query, parameters); err != nil {
return graph.NewErrorResult(err)
} else {
Expand All @@ -301,7 +306,7 @@ func (s *transaction) Raw(query string, parameters map[string]any) graph.Result
}
}

func (s *transaction) Commit() error {
func (s *txWrapper) Commit() error {
if s.tx != nil {
return s.tx.Commit(s.ctx)
}
Expand Down

0 comments on commit f578d5c

Please sign in to comment.