From ed2d94337454721fa6b6033c71ec0af2a6498815 Mon Sep 17 00:00:00 2001 From: Drew Kimball Date: Tue, 29 Aug 2023 19:22:35 -0600 Subject: [PATCH 1/3] plpgsql: add parser support for cursors This patch adds support in the PLpgSQL parser for the following commands related to cursors: `DECLARE`, `OPEN`, `FETCH`, `MOVE`, and `CLOSE`. The `OPEN ... FOR EXECUTE ...` syntax is not currently implemented. Informs #105254 Release note: None --- pkg/sql/opt/optbuilder/plpgsql.go | 9 +- pkg/sql/plpgsql/parser/lexer.go | 126 ++++---- pkg/sql/plpgsql/parser/plpgsql.y | 162 +++++----- pkg/sql/plpgsql/parser/testdata/decl_header | 13 +- pkg/sql/plpgsql/parser/testdata/stmt_close | 2 +- .../plpgsql/parser/testdata/stmt_fetch_move | 302 +++++++++++++++++- pkg/sql/plpgsql/parser/testdata/stmt_open | 37 ++- pkg/sql/sem/plpgsqltree/constants.go | 98 ------ pkg/sql/sem/plpgsqltree/statements.go | 127 +++++--- 9 files changed, 572 insertions(+), 304 deletions(-) diff --git a/pkg/sql/opt/optbuilder/plpgsql.go b/pkg/sql/opt/optbuilder/plpgsql.go index a9ee8f084aa6..7a2f18ee2849 100644 --- a/pkg/sql/opt/optbuilder/plpgsql.go +++ b/pkg/sql/opt/optbuilder/plpgsql.go @@ -152,7 +152,14 @@ func (b *plpgsqlBuilder) init( b.ob = ob b.colRefs = colRefs b.params = params - b.decls = block.Decls + for i := range block.Decls { + switch dec := block.Decls[i].(type) { + case *ast.Declaration: + b.decls = append(b.decls, *dec) + case *ast.CursorDeclaration: + panic(unimplemented.New("bound cursors", "bound cursor declarations are not yet supported.")) + } + } b.returnType = returnType b.varTypes = make(map[tree.Name]*types.T) for _, dec := range b.decls { diff --git a/pkg/sql/plpgsql/parser/lexer.go b/pkg/sql/plpgsql/parser/lexer.go index 442f85d5e0e3..517540e5d1ed 100644 --- a/pkg/sql/plpgsql/parser/lexer.go +++ b/pkg/sql/plpgsql/parser/lexer.go @@ -246,56 +246,6 @@ func (l *lexer) MakeDynamicExecuteStmt() *plpgsqltree.DynamicExecute { return ret } -func (l *lexer) ProcessForOpenCursor(nullCursorExplicitExpr bool) *plpgsqltree.Open { - openStmt := &plpgsqltree.Open{} - openStmt.CursorOptions = plpgsqltree.CursorOptionFastPlan.Mask() - - if nullCursorExplicitExpr { - if l.Peek().id == NO { - l.lastPos++ - if l.Peek().id == SCROLL { - openStmt.CursorOptions |= plpgsqltree.CursorOptionNoScroll.Mask() - l.lastPos++ - } - } else if l.Peek().id == SCROLL { - openStmt.CursorOptions |= plpgsqltree.CursorOptionScroll.Mask() - l.lastPos++ - } - - if l.Peek().id != FOR { - l.setErr(pgerror.New(pgcode.Syntax, "syntax error, expected \"FOR\"")) - return nil - } - - l.lastPos++ - if l.Peek().id == EXECUTE { - l.lastPos++ - dynamicQuery, endToken := l.ReadSqlExpressionStr2(USING, ';') - openStmt.DynamicQuery = dynamicQuery - l.lastPos++ - if endToken == USING { - // Continue reading for params for the sql expression till the ending - // token is not a comma. - openStmt.Params = make([]string, 0) - for { - param, endToken := l.ReadSqlExpressionStr2(',', ';') - openStmt.Params = append(openStmt.Params, param) - if endToken != ',' { - break - } - l.lastPos++ - } - } - } else { - openStmt.Query = l.ReadSqlExpressionStr(';') - } - } else { - // read_cursor_args() - openStmt.ArgQuery = "hello" - } - return openStmt -} - // ReadSqlExpressionStr returns the string from the l.lastPos till it sees // the terminator for the first time. The returned string is made by tokens // between the starting index (included) to the terminator (not included). @@ -360,6 +310,62 @@ func (l *lexer) readSQLConstruct( return startPos, endPos, terminatorMet } +func (l *lexer) MakeFetchOrMoveStmt(isMove bool) (plpgsqltree.Statement, error) { + if l.parser.Lookahead() != -1 { + // Push back the lookahead token so that it can be included. + l.PushBack(1) + } + prefix := "FETCH " + if isMove { + prefix = "MOVE " + } + sqlStr, terminator := l.ReadSqlConstruct(INTO, ';') + sqlStr = prefix + sqlStr + sqlStmt, err := parser.ParseOne(sqlStr) + if err != nil { + return nil, err + } + var cursor tree.CursorStmt + switch t := sqlStmt.AST.(type) { + case *tree.FetchCursor: + cursor = t.CursorStmt + case *tree.MoveCursor: + cursor = t.CursorStmt + default: + return nil, errors.Newf("invalid FETCH or MOVE syntax") + } + var target []plpgsqltree.Variable + if !isMove { + if terminator != INTO { + return nil, errors.Newf("invalid syntax for FETCH") + } + // Read past the INTO. + l.lastPos++ + startPos, endPos, _ := l.readSQLConstruct(';') + for pos := startPos; pos < endPos; pos += 2 { + tok := l.tokens[pos] + if tok.id != IDENT { + return nil, errors.Newf("\"%s\" is not a scalar variable", tok.str) + } + if pos+1 != endPos && l.tokens[pos+1].id != ',' { + return nil, errors.Newf("expected INTO target to be a comma-separated list") + } + variable := plpgsqltree.Variable(strings.TrimSpace(l.getStr(pos, pos+1))) + target = append(target, variable) + } + if len(target) == 0 { + return nil, errors.Newf("expected INTO target") + } + } + // Move past the semicolon. + l.lastPos++ + return &plpgsqltree.Fetch{ + Cursor: cursor, + Target: target, + IsMove: isMove, + }, nil +} + func (l *lexer) ReadSqlConstruct( terminator1 int, terminators ...int, ) (sqlStr string, terminatorMet int) { @@ -380,26 +386,6 @@ func (l *lexer) getStr(startPos, endPos int) string { return l.in[start:end] } -func (l *lexer) ProcessQueryForCursorWithoutExplicitExpr(openStmt *plpgsqltree.Open) { - l.lastPos++ - if int(l.Peek().id) == EXECUTE { - dynamicQuery, endToken := l.ReadSqlExpressionStr2(USING, ';') - openStmt.DynamicQuery = dynamicQuery - if endToken == USING { - var expr string - for { - expr, endToken = l.ReadSqlExpressionStr2(',', ';') - openStmt.Params = append(openStmt.Params, expr) - if endToken != ',' { - break - } - } - } - } else { - openStmt.Query = l.ReadSqlExpressionStr(';') - } -} - // Peek peeks func (l *lexer) Peek() plpgsqlSymType { if l.lastPos+1 < len(l.tokens) { diff --git a/pkg/sql/plpgsql/parser/plpgsql.y b/pkg/sql/plpgsql/parser/plpgsql.y index 9cfc31443304..5ddc6a4698af 100644 --- a/pkg/sql/plpgsql/parser/plpgsql.y +++ b/pkg/sql/plpgsql/parser/plpgsql.y @@ -2,6 +2,7 @@ package parser import ( + "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/scanner" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree" @@ -124,7 +125,7 @@ func (u *plpgsqlSymUnion) open() *plpgsqltree.Open { func (u *plpgsqlSymUnion) expr() plpgsqltree.Expr { if u.val == nil { - return nil + return nil } return u.val.(plpgsqltree.Expr) } @@ -133,14 +134,6 @@ func (u *plpgsqlSymUnion) exprs() []plpgsqltree.Expr { return u.val.([]plpgsqltree.Expr) } -func (u *plpgsqlSymUnion) declaration() *plpgsqltree.Declaration { - return u.val.(*plpgsqltree.Declaration) -} - -func (u *plpgsqlSymUnion) declarations() []plpgsqltree.Declaration { - return u.val.([]plpgsqltree.Declaration) -} - func (u *plpgsqlSymUnion) raiseOption() *plpgsqltree.RaiseOption { return u.val.(*plpgsqltree.RaiseOption) } @@ -166,6 +159,14 @@ func (u *plpgsqlSymUnion) conditions() []plpgsqltree.Condition { return u.val.([]plpgsqltree.Condition) } +func (u *plpgsqlSymUnion) cursorScrollOption() tree.CursorScrollOption { + return u.val.(tree.CursorScrollOption) +} + +func (u *plpgsqlSymUnion) sqlStatement() tree.Statement { + return u.val.(tree.Statement) +} + %} /* * Basic non-keyword token types. These are hard-wired into the core lexer. @@ -315,7 +316,6 @@ func (u *plpgsqlSymUnion) conditions() []plpgsqltree.Condition { %type decl_datatype %type decl_collate -%type <*plpgsqltree.Open> open_stmt_processor %type expr_until_semi expr_until_paren %type expr_until_then expr_until_loop opt_expr_until_when %type opt_exitcond @@ -340,8 +340,8 @@ func (u *plpgsqlSymUnion) conditions() []plpgsqltree.Condition { %type stmt_commit stmt_rollback %type stmt_case stmt_foreach_a -%type <*plpgsqltree.Declaration> decl_stmt decl_statement -%type <[]plpgsqltree.Declaration> decl_sect opt_decl_stmts decl_stmts +%type decl_stmt decl_statement +%type <[]plpgsqltree.Statement> decl_sect opt_decl_stmts decl_stmts %type <[]plpgsqltree.Exception> exception_sect proc_exceptions %type <*plpgsqltree.Exception> proc_exception @@ -362,9 +362,7 @@ func (u *plpgsqlSymUnion) conditions() []plpgsqltree.Condition { %type format_expr %type <[]plpgsqltree.Expr> opt_format_exprs format_exprs -%type opt_scrollable - -%type <*plpgsqltree.Fetch> opt_fetch_direction +%type opt_scrollable %type <*tree.NumVal> opt_transaction_chain @@ -385,7 +383,7 @@ pl_block: opt_block_label decl_sect BEGIN proc_sect exception_sect END opt_label { $$.val = &plpgsqltree.Block{ Label: $1, - Decls: $2.declarations(), + Decls: $2.statements(), Body: $4.statements(), Exceptions: $5.exceptions(), } @@ -394,54 +392,46 @@ pl_block: opt_block_label decl_sect BEGIN proc_sect exception_sect END opt_label decl_sect: DECLARE opt_decl_stmts { - $$.val = $2.declarations() + $$.val = $2.statements() } | /* EMPTY */ { // Use a nil slice to indicate DECLARE was not used. - $$.val = []plpgsqltree.Declaration(nil) + $$.val = []plpgsqltree.Statement(nil) } ; opt_decl_stmts: decl_stmts { - $$.val = $1.declarations() + $$.val = $1.statements() } | /* EMPTY */ { - $$.val = []plpgsqltree.Declaration{} + $$.val = []plpgsqltree.Statement{} } ; decl_stmts: decl_stmts decl_stmt { - decs := $1.declarations() - dec := $2.declaration() - if dec == nil { - $$.val = decs - } else { - $$.val = append(decs, *dec) - } + decs := $1.statements() + dec := $2.statement() + $$.val = append(decs, dec) } | decl_stmt { - dec := $1.declaration() - if dec == nil { - $$.val = []plpgsqltree.Declaration{} - } else { - $$.val = []plpgsqltree.Declaration{*dec} - } + dec := $1.statement() + $$.val = []plpgsqltree.Statement{dec} } ; decl_stmt : decl_statement { - $$.val = $1.declaration() + $$.val = $1.statement() } | DECLARE { // This is to allow useless extra "DECLARE" keywords in the declare section. - $$.val = (*plpgsqltree.Declaration)(nil) + $$.val = (plpgsqltree.Statement)(nil) } // TODO(chengxiong): turn this block on and throw useful error if user // tries to put the block label just before BEGIN instead of before @@ -466,36 +456,48 @@ decl_statement: decl_varname decl_const decl_datatype decl_collate decl_notnull { return unimplemented(plpgsqllex, "alias for") } -| decl_varname opt_scrollable CURSOR decl_cursor_args decl_is_for decl_cursor_query ';' +| decl_varname opt_scrollable CURSOR decl_cursor_args decl_is_for decl_cursor_query { - return unimplemented(plpgsqllex, "cursor") + $$.val = &plpgsqltree.CursorDeclaration{ + Name: plpgsqltree.Variable($1), + Scroll: $2.cursorScrollOption(), + Query: $6.sqlStatement(), + } } ; opt_scrollable: { - return unimplemented(plpgsqllex, "cursor") + $$.val = tree.UnspecifiedScroll } | NO_SCROLL SCROLL { - return unimplemented(plpgsqllex, "cursor") + $$.val = tree.NoScroll } | SCROLL { - return unimplemented(plpgsqllex, "cursor") + $$.val = tree.Scroll } ; -decl_cursor_query: +decl_cursor_query: expr_until_semi ';' { - plpgsqllex.(*lexer).ReadSqlExpressionStr(';') + stmts, err := parser.Parse($1) + if err != nil { + return setErr(plpgsqllex, err) + } + if len(stmts) != 1 { + return setErr(plpgsqllex, errors.New("expected exactly one SQL statement for cursor")) + } + $$.val = stmts[0].AST } ; -decl_cursor_args: +decl_cursor_args: '(' { + return unimplemented(plpgsqllex, "cursor arguments") } -| '(' decl_cursor_arglist ')' +| /* EMPTY */ { } ; @@ -687,11 +689,17 @@ proc_stmt:pl_block ';' | stmt_getdiag { } | stmt_open - { } + { + $$.val = $1.statement() + } | stmt_fetch - { } + { + $$.val = $1.statement() + } | stmt_move - { } + { + $$.val = $1.statement() + } | stmt_close { $$.val = $1.statement() @@ -1247,35 +1255,54 @@ stmt_dynexecute: EXECUTE } ; -// TODO: change expr_until_semi to process_cursor_before_semi -stmt_open: OPEN IDENT open_stmt_processor ';' +stmt_open: OPEN IDENT ';' { - openCursorStmt := $3.open() - openCursorStmt.CursorName = $2 - $$.val = openCursorStmt + $$.val = &plpgsqltree.Open{CurVar: plpgsqltree.Variable($2)} } -; - -stmt_fetch: FETCH opt_fetch_direction IDENT INTO +| OPEN IDENT opt_scrollable FOR EXECUTE { - return unimplemented(plpgsqllex, "fetch") + return unimplemented(plpgsqllex, "cursor for execute") + } +| OPEN IDENT opt_scrollable FOR expr_until_semi ';' + { + stmts, err := parser.Parse($5) + if err != nil { + return setErr(plpgsqllex, err) + } + if len(stmts) != 1 { + return setErr(plpgsqllex, errors.New("expected exactly one SQL statement for cursor")) + } + $$.val = &plpgsqltree.Open{ + CurVar: plpgsqltree.Variable($2), + Scroll: $3.cursorScrollOption(), + Query: stmts[0].AST, + } } ; -stmt_move: MOVE opt_fetch_direction IDENT ';' +stmt_fetch: FETCH { - return unimplemented(plpgsqllex, "move") + fetch, err := plpgsqllex.(*lexer).MakeFetchOrMoveStmt(false) + if err != nil { + return setErr(plpgsqllex, err) + } + $$.val = fetch } ; -opt_fetch_direction: +stmt_move: MOVE { - return unimplemented(plpgsqllex, "fetch direction") + move, err := plpgsqllex.(*lexer).MakeFetchOrMoveStmt(true) + if err != nil { + return setErr(plpgsqllex, err) + } + $$.val = move } +; -stmt_close: CLOSE cursor_variable ';' +stmt_close: CLOSE IDENT ';' { - $$.val = &plpgsqltree.Close{} + $$.val = &plpgsqltree.Close{CurVar: plpgsqltree.Variable($2)} } ; @@ -1305,12 +1332,6 @@ AND CHAIN | /* EMPTY */ { } -cursor_variable: IDENT - { - unimplemented(plpgsqllex, "cursor variable") - } -; - exception_sect: /* EMPTY */ { $$.val = []plpgsqltree.Exception(nil) @@ -1364,11 +1385,6 @@ proc_condition: any_identifier } ; -open_stmt_processor: - { - $$.val = plpgsqllex.(*lexer).ProcessForOpenCursor(true) - } - expr_until_semi: { $$ = plpgsqllex.(*lexer).ReadSqlExpressionStr(';') diff --git a/pkg/sql/plpgsql/parser/testdata/decl_header b/pkg/sql/plpgsql/parser/testdata/decl_header index 086024b6f5fe..3cf7d9132b35 100644 --- a/pkg/sql/plpgsql/parser/testdata/decl_header +++ b/pkg/sql/plpgsql/parser/testdata/decl_header @@ -40,10 +40,21 @@ END ---- at or near ";": syntax error: unimplemented: this syntax +parse +DECLARE + var1 CURSOR FOR SELECT * FROM t1 WHERE id = arg1; +BEGIN +END +---- +DECLARE +var1 CURSOR FOR SELECT * FROM t1 WHERE id = arg1; +BEGIN +END + parse DECLARE var1 NO SCROLL CURSOR (arg1 INTEGER) FOR SELECT * FROM t1 WHERE id = arg1; BEGIN END ---- -at or near "scroll": syntax error: unimplemented: this syntax +at or near "(": syntax error: unimplemented: this syntax diff --git a/pkg/sql/plpgsql/parser/testdata/stmt_close b/pkg/sql/plpgsql/parser/testdata/stmt_close index b61e63fc98bf..90edb47193a7 100644 --- a/pkg/sql/plpgsql/parser/testdata/stmt_close +++ b/pkg/sql/plpgsql/parser/testdata/stmt_close @@ -6,5 +6,5 @@ END ---- DECLARE BEGIN -CLOSE a cursor +CLOSE some_cursor; END diff --git a/pkg/sql/plpgsql/parser/testdata/stmt_fetch_move b/pkg/sql/plpgsql/parser/testdata/stmt_fetch_move index ac26c7b4b811..bceef82ba3b2 100644 --- a/pkg/sql/plpgsql/parser/testdata/stmt_fetch_move +++ b/pkg/sql/plpgsql/parser/testdata/stmt_fetch_move @@ -4,7 +4,10 @@ BEGIN MOVE NEXT FROM emp_cur; END ---- -at or near "move": syntax error: unimplemented: this syntax +DECLARE +BEGIN +MOVE 1 FROM emp_cur; +END parse DECLARE @@ -12,7 +15,10 @@ BEGIN MOVE PRIOR FROM var; END ---- -at or near "move": syntax error: unimplemented: this syntax +DECLARE +BEGIN +MOVE -1 FROM var; +END parse DECLARE @@ -20,7 +26,10 @@ BEGIN FETCH NEXT FROM emp_cur INTO x,y; END ---- -at or near "fetch": syntax error: unimplemented: this syntax +DECLARE +BEGIN +FETCH 1 FROM emp_cur INTO x, y; +END parse DECLARE @@ -28,7 +37,10 @@ BEGIN FETCH emp_cur INTO x,y; END ---- -at or near "fetch": syntax error: unimplemented: this syntax +DECLARE +BEGIN +FETCH 1 FROM emp_cur INTO x, y; +END parse DECLARE @@ -36,4 +48,284 @@ BEGIN FETCH ABSOLUTE 2 FROM emp_cur INTO x,y; END ---- -at or near "fetch": syntax error: unimplemented: this syntax +DECLARE +BEGIN +FETCH ABSOLUTE 2 FROM emp_cur INTO x, y; +END + +parse +DECLARE +BEGIN +FETCH emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH 1 FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH emp_cur INTO; +END +---- +at or near "into": syntax error: expected INTO target + +parse +DECLARE +BEGIN +FETCH emp_cur; +END +---- +at or near "emp_cur": syntax error: invalid syntax for FETCH + +parse +DECLARE +BEGIN +MOVE NEXT FROM emp_cur INTO x, y; +END +---- +at or near ";": at or near "x": syntax error + +parse +DECLARE +BEGIN +MOVE NEXT FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE 1 FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE PRIOR FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE -1 FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE FIRST FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE FIRST FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE LAST FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE LAST FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE ABSOLUTE 5 FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE ABSOLUTE 5 FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE FIRST FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE FIRST FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE RELATIVE 3 FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE RELATIVE 3 FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE FORWARD 3 FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE 3 FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE BACKWARD 3 FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE -3 FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE FORWARD ALL FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE ALL FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE BACKWARD ALL FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE BACKWARD ALL FROM emp_cur; +END + +parse +DECLARE +BEGIN +FETCH NEXT FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH 1 FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH PRIOR FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH -1 FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH FIRST FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH FIRST FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH LAST FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH LAST FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH ABSOLUTE 5 FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH ABSOLUTE 5 FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH FIRST FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH FIRST FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH RELATIVE 3 FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH RELATIVE 3 FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH FORWARD 3 FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH 3 FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH BACKWARD 3 FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH -3 FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH FORWARD ALL FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH ALL FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH BACKWARD ALL FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH BACKWARD ALL FROM emp_cur INTO x; +END diff --git a/pkg/sql/plpgsql/parser/testdata/stmt_open b/pkg/sql/plpgsql/parser/testdata/stmt_open index 73a061c8e2a3..d91b472f6352 100644 --- a/pkg/sql/plpgsql/parser/testdata/stmt_open +++ b/pkg/sql/plpgsql/parser/testdata/stmt_open @@ -1,22 +1,51 @@ parse DECLARE BEGIN -OPEN curs1 NO SCROLL FOR SELECT * FROM foo WHERE key = mykey; +OPEN curs1; END ---- DECLARE BEGIN -OPEN curs1 NO SCROLL FOR SELECT * FROM foo WHERE key = mykey +OPEN curs1; END +parse +DECLARE +BEGIN +OPEN curs1 FOR SELECT * FROM foo WHERE key = mykey; +END +---- +DECLARE +BEGIN +OPEN curs1 FOR SELECT * FROM foo WHERE key = mykey; +END parse DECLARE BEGIN -OPEN curs2 SCROLL FOR EXECUTE SELECT $1, $2 FROM foo WHERE key = mykey USING hello, jojo; +OPEN curs1 SCROLL FOR SELECT * FROM foo WHERE key = mykey; +END +---- +DECLARE +BEGIN +OPEN curs1 SCROLL FOR SELECT * FROM foo WHERE key = mykey; +END + +parse +DECLARE +BEGIN +OPEN curs1 NO SCROLL FOR SELECT * FROM foo WHERE key = mykey; END ---- DECLARE BEGIN -OPEN curs2 SCROLL FOR EXECUTE SELECT $1, $2 FROM foo WHERE key = mykey USING [hello jojo] +OPEN curs1 NO SCROLL FOR SELECT * FROM foo WHERE key = mykey; +END + +parse +DECLARE +BEGIN +OPEN curs2 SCROLL FOR EXECUTE SELECT $1, $2 FROM foo WHERE key = mykey USING hello, jojo; END +---- +at or near "execute": syntax error: unimplemented: this syntax diff --git a/pkg/sql/sem/plpgsqltree/constants.go b/pkg/sql/sem/plpgsqltree/constants.go index 8c8c389406cf..67ebc342990c 100644 --- a/pkg/sql/sem/plpgsqltree/constants.go +++ b/pkg/sql/sem/plpgsqltree/constants.go @@ -79,102 +79,4 @@ func (k GetDiagnosticsKind) String() string { return "SCHEMA_NAME" } panic(errors.AssertionFailedf("unknown diagnostics kind")) - -} - -// FetchDirection represents the direction clause passed into a fetch statement. -type FetchDirection int - -// CursorOption represents a cursor option, which describes how a cursor will -// behave. -type CursorOption uint32 - -const ( - // CursorOptionNone - CursorOptionNone CursorOption = iota - // CursorOptionBinary describes cursors that return data in binary form. - CursorOptionBinary - // CursorOptionScroll describes cursors that can retrieve rows in - // non-sequential fashion. - CursorOptionScroll - // CursorOptionNoScroll describes cursors that can not retrieve rows in - // non-sequential fashion. - CursorOptionNoScroll - // CursorOptionInsensitive describes cursors that can't see changes to - // done to data in same txn. - CursorOptionInsensitive - // CursorOPtionAsensitive describes cursors that may be able to see - // changes to done to data in same txn. - CursorOPtionAsensitive - // CursorOptionHold describes cursors that can be used after a txn that it - // was created in commits. - CursorOptionHold - // CursorOptionFastPlan describes cursors that can not be used after a txn - // that it was created in commits. - CursorOptionFastPlan - // CursorOptionGenericPlan describes cursors that uses a generic plan. - CursorOptionGenericPlan - // CursorOptionCustomPlan describes cursors that uses a custom plan. - CursorOptionCustomPlan - // CursorOptionParallelOK describes cursors that allows parallel queries. - CursorOptionParallelOK -) - -// String implements the fmt.Stringer interface. -func (o CursorOption) String() string { - switch o { - case CursorOptionNoScroll: - return "NO SCROLL" - case CursorOptionScroll: - return "SCROLL" - case CursorOptionFastPlan: - return "" - // TODO(jane): implement string representation for other opts. - default: - return "NOT_IMPLEMENTED_OPT" - } -} - -// Mask returns the bitmask for a given cursor option. -func (o CursorOption) Mask() uint32 { - return 1 << o -} - -// IsSetIn returns true if this cursor option is set in the supplied bitfield. -func (o CursorOption) IsSetIn(bits uint32) bool { - return bits&o.Mask() != 0 -} - -type cursorOptionList []CursorOption - -// ToBitField returns the bitfield representation of a list of cursor options. -func (ol cursorOptionList) ToBitField() uint32 { - var ret uint32 - for _, o := range ol { - ret |= o.Mask() - } - return ret -} - -// OptListFromBitField returns a list of cursor option to be printed. -func OptListFromBitField(m uint32) cursorOptionList { - ret := cursorOptionList{} - opts := []CursorOption{ - CursorOptionBinary, - CursorOptionScroll, - CursorOptionNoScroll, - CursorOptionInsensitive, - CursorOPtionAsensitive, - CursorOptionHold, - CursorOptionFastPlan, - CursorOptionGenericPlan, - CursorOptionCustomPlan, - CursorOptionParallelOK, - } - for _, opt := range opts { - if opt.IsSetIn(m) { - ret = append(ret, opt) - } - } - return ret } diff --git a/pkg/sql/sem/plpgsqltree/statements.go b/pkg/sql/sem/plpgsqltree/statements.go index 68da4883dce2..02dc14c36787 100644 --- a/pkg/sql/sem/plpgsqltree/statements.go +++ b/pkg/sql/sem/plpgsqltree/statements.go @@ -12,6 +12,7 @@ package plpgsqltree import ( "fmt" + "strconv" "strings" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" @@ -57,7 +58,7 @@ func (s *StatementImpl) plpgsqlStmt() {} type Block struct { StatementImpl Label string - Decls []Declaration + Decls []Statement Body []Statement Exceptions []Exception } @@ -134,6 +135,34 @@ func (s *Declaration) WalkStmt(visitor StatementVisitor) { visitor.Visit(s) } +type CursorDeclaration struct { + StatementImpl + Name Variable + Scroll tree.CursorScrollOption + Query tree.Statement +} + +func (s *CursorDeclaration) Format(ctx *tree.FmtCtx) { + ctx.WriteString(string(s.Name)) + switch s.Scroll { + case tree.Scroll: + ctx.WriteString(" SCROLL") + case tree.NoScroll: + ctx.WriteString(" NO SCROLL") + } + ctx.WriteString(" CURSOR FOR ") + s.Query.Format(ctx) + ctx.WriteString(";\n") +} + +func (s *CursorDeclaration) PlpgSQLStatementTag() string { + return "decl_cursor_stmt" +} + +func (s *CursorDeclaration) WalkStmt(visitor StatementVisitor) { + visitor.Visit(s) +} + // stmt_assign type Assignment struct { Statement @@ -850,50 +879,25 @@ func (s *GetDiagnostics) WalkStmt(visitor StatementVisitor) { // stmt_open type Open struct { StatementImpl - CurVar int // TODO(drewk): this could just a Variable - CursorOptions uint32 - // TODO(jane): This is temporary and we should remove it and use CurVar. - CursorName string - WithExplicitExpr bool - // TODO(jane): Should be Expr - ArgQuery string - // TODO(jane): Should be Expr - Query string - // TODO(jane): Should be Expr - DynamicQuery string - // TODO(jane): Should be []Expr - Params []string + CurVar Variable + Scroll tree.CursorScrollOption + Query tree.Statement } func (s *Open) Format(ctx *tree.FmtCtx) { - ctx.WriteString( - fmt.Sprintf( - "OPEN %s ", - s.CursorName, - )) - - opts := OptListFromBitField(s.CursorOptions) - for _, opt := range opts { - if opt.String() != "" { - ctx.WriteString(fmt.Sprintf("%s ", opt.String())) - } + ctx.WriteString("OPEN ") + s.CurVar.Format(ctx) + switch s.Scroll { + case tree.Scroll: + ctx.WriteString(" SCROLL") + case tree.NoScroll: + ctx.WriteString(" NO SCROLL") } - if !s.WithExplicitExpr { - ctx.WriteString("FOR ") - if s.DynamicQuery != "" { - // TODO(drewk): Make sure placeholders are properly printed - ctx.WriteString(fmt.Sprintf("EXECUTE %s ", s.DynamicQuery)) - if len(s.Params) != 0 { - // TODO(drewk): Dont print instances of multiple params with brackets `[...]` - ctx.WriteString(fmt.Sprintf("USING %s", s.Params)) - } - } else { - ctx.WriteString(s.Query) - } - } else { - ctx.WriteString(s.ArgQuery) + if s.Query != nil { + ctx.WriteString(" FOR ") + s.Query.Format(ctx) } - ctx.WriteString("\n") + ctx.WriteString(";\n") } func (s *Open) PlpgSQLStatementTag() string { @@ -908,16 +912,37 @@ func (s *Open) WalkStmt(visitor StatementVisitor) { // stmt_move (where IsMove = true) type Fetch struct { StatementImpl - Target Variable - CurVar int // TODO(drewk): this could just a Variable - Direction FetchDirection - HowMany int64 - Expr Expr - IsMove bool - ReturnsMultiRows bool + Cursor tree.CursorStmt + Target []Variable + IsMove bool } func (s *Fetch) Format(ctx *tree.FmtCtx) { + if s.IsMove { + ctx.WriteString("MOVE ") + } else { + ctx.WriteString("FETCH ") + } + if dir := s.Cursor.FetchType.String(); dir != "" { + ctx.WriteString(dir) + ctx.WriteString(" ") + } + if s.Cursor.FetchType.HasCount() { + ctx.WriteString(strconv.Itoa(int(s.Cursor.Count))) + ctx.WriteString(" ") + } + ctx.WriteString("FROM ") + s.Cursor.Name.Format(ctx) + if s.Target != nil { + ctx.WriteString(" INTO ") + for i := range s.Target { + if i > 0 { + ctx.WriteString(", ") + } + s.Target[i].Format(ctx) + } + } + ctx.WriteString(";\n") } func (s *Fetch) PlpgSQLStatementTag() string { @@ -934,13 +959,13 @@ func (s *Fetch) WalkStmt(visitor StatementVisitor) { // stmt_close type Close struct { StatementImpl - CurVar int // TODO(drewk): this could just a Variable + CurVar Variable } func (s *Close) Format(ctx *tree.FmtCtx) { - // TODO(drewk): Pretty- Print the cursor identifier - ctx.WriteString("CLOSE a cursor\n") - + ctx.WriteString("CLOSE ") + s.CurVar.Format(ctx) + ctx.WriteString(";\n") } func (s *Close) PlpgSQLStatementTag() string { From 3ee8f897290fd762964c73b3e3ea6cca1d45df87 Mon Sep 17 00:00:00 2001 From: Drew Kimball Date: Fri, 15 Sep 2023 02:56:57 -0600 Subject: [PATCH 2/3] plpgsql: add execution support for OPEN statements This patch adds support for executing PLpgSQL OPEN statements, which open a SQL cursor in the current transaction. The name of the cursor is supplied through a PLpgSQL variable. Since the `REFCURSOR` type hasn't been implemented yet, this patch uses `STRING` in the meantime. Limitations that will be lifted in future PRs: 1. Unnamed cursor declarations are not supported. If a cursor is opened with no name supplied, a name should be automatically generated. 2. Bound cursors are not yet supported. It should be possible to declare a cursor in the `DECLARE` block with the query already defined, at which point it can be opened with `OPEN ;`. 3. A cursor cannot be opened in a routine with an exception block. This is because correct handling of this case is waiting on separate work to implement rollback of changes to database state on exceptions. Informs #109709 Release note (sql change): Added initial support for executing the PLpgSQL `OPEN` statement, which allows a PLpgSQL routine to create a cursor. Currently, opening bound or unnamed cursors is not supported. In addition, `OPEN` statements cannot be used in a routine with an exception block. --- .../tests/3node-tenant/generated_test.go | 7 + pkg/sql/conn_executor.go | 11 +- .../testdata/logic_test/plpgsql_cursor | 400 ++++++++++++++++++ .../tests/fakedist-disk/generated_test.go | 7 + .../tests/fakedist-vec-off/generated_test.go | 7 + .../tests/fakedist/generated_test.go | 7 + .../generated_test.go | 7 + .../local-mixed-22.2-23.1/generated_test.go | 7 + .../tests/local-vec-off/generated_test.go | 7 + .../logictest/tests/local/generated_test.go | 7 + pkg/sql/opt/exec/execbuilder/relational.go | 1 + pkg/sql/opt/exec/execbuilder/scalar.go | 5 + pkg/sql/opt/memo/expr.go | 7 + pkg/sql/opt/memo/expr_format.go | 6 + pkg/sql/opt/memo/interner.go | 27 ++ pkg/sql/opt/norm/inline_funcs.go | 5 +- pkg/sql/opt/optbuilder/plpgsql.go | 60 ++- pkg/sql/opt/optbuilder/scope_column.go | 9 + pkg/sql/opt/optbuilder/testdata/udf_plpgsql | 219 ++++++++++ pkg/sql/routine.go | 137 +++++- pkg/sql/sem/tree/routine.go | 22 + pkg/sql/sql_cursor.go | 20 +- 22 files changed, 960 insertions(+), 25 deletions(-) create mode 100644 pkg/sql/logictest/testdata/logic_test/plpgsql_cursor diff --git a/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go b/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go index 8c7ae216c92c..49ed623311c2 100644 --- a/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go +++ b/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go @@ -1381,6 +1381,13 @@ func TestTenantLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestTenantLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestTenantLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/conn_executor.go b/pkg/sql/conn_executor.go index 9ab7ad6405a0..d86314936247 100644 --- a/pkg/sql/conn_executor.go +++ b/pkg/sql/conn_executor.go @@ -1214,13 +1214,16 @@ func (ex *connExecutor) close(ctx context.Context, closeType closeType) { txnEvType = txnRollback } - // Close all portals, otherwise there will be leftover bytes. + // Close all portals and cursors, otherwise there will be leftover bytes. ex.extraTxnState.prepStmtsNamespace.closeAllPortals( ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, ) ex.extraTxnState.prepStmtsNamespaceAtTxnRewindPos.closeAllPortals( ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, ) + if err := ex.extraTxnState.sqlCursors.closeAll(false /* errorOnWithHold */); err != nil { + log.Warningf(ctx, "error closing cursors: %v", err) + } var payloadErr error if closeType == normalClose { @@ -1271,7 +1274,8 @@ func (ex *connExecutor) close(ctx context.Context, closeType closeType) { } if closeType != panicClose { - // Close all statements, prepared portals, and cursors. + // Close all statements and prepared portals. The cursors have already been + // closed. ex.extraTxnState.prepStmtsNamespace.resetToEmpty( ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, ) @@ -1279,9 +1283,6 @@ func (ex *connExecutor) close(ctx context.Context, closeType closeType) { ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, ) ex.extraTxnState.prepStmtsNamespaceMemAcc.Close(ctx) - if err := ex.extraTxnState.sqlCursors.closeAll(false /* errorOnWithHold */); err != nil { - log.Warningf(ctx, "error closing cursors: %v", err) - } } if ex.sessionTracing.Enabled() { diff --git a/pkg/sql/logictest/testdata/logic_test/plpgsql_cursor b/pkg/sql/logictest/testdata/logic_test/plpgsql_cursor new file mode 100644 index 000000000000..d8d4684a3794 --- /dev/null +++ b/pkg/sql/logictest/testdata/logic_test/plpgsql_cursor @@ -0,0 +1,400 @@ +statement ok +CREATE TABLE xy (x INT, y INT); +INSERT INTO xy VALUES (1, 2), (3, 4); + +statement ok +CREATE TABLE kv (k INT PRIMARY KEY, v INT); +INSERT INTO kv VALUES (1, 2), (3, 4); + +# Testing OPEN statements. +statement ok +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query I +FETCH FORWARD 3 FROM foo; +---- +1 + +statement ok +ABORT; + +statement error pgcode 34000 pq: cursor \"foo\" does not exist +FETCH FORWARD 3 FROM foo; + +statement ok +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + x INT := 10; + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT x; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query I +FETCH FORWARD 3 FROM foo; +---- +10 + +# TODO(drewk): postgres returns an ambiguous column error here by default, +# although it can be configured to prefer either the variable or the column. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + x INT := 10; + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT * FROM xy WHERE xy.x = x; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query II rowsort +FETCH FORWARD 10 FROM foo; +---- +1 2 +3 4 + +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + i INT := 3; + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT * FROM xy WHERE x = i; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query II +FETCH FORWARD 3 FROM foo; +---- +3 4 + +# It should be possible to fetch from the cursor incrementally. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs NO SCROLL FOR SELECT * FROM (VALUES (1, 2), (3, 4), (5, 6), (7, 8)) v(a, b); + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query II rowsort +FETCH FORWARD 1 FROM foo; +---- +1 2 + +query II rowsort +FETCH FORWARD 2 FROM foo; +---- +3 4 +5 6 + +query II rowsort +FETCH FORWARD 3 FROM foo; +---- +7 8 + +# Cursor with NO SCROLL option. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs NO SCROLL FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query I +FETCH FORWARD 3 FROM foo; +---- +1 + +# Cursor with empty-string name. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := ''; + BEGIN + OPEN curs FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query I +FETCH FORWARD 3 FROM ""; +---- +1 + +# Multiple cursors. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + curs2 STRING := 'bar'; + curs3 STRING := 'baz'; + BEGIN + OPEN curs FOR SELECT 1; + OPEN curs2 FOR SELECT 2; + OPEN curs3 FOR SELECT 3; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query I +FETCH FORWARD 3 FROM foo; +---- +1 + +query I +FETCH FORWARD 3 FROM bar; +---- +2 + +query I +FETCH FORWARD 3 FROM baz; +---- +3 + +# The cursor should reflect changes to the database state that occur before +# it is opened, but not those that happen after it is opened. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + curs2 STRING := 'bar'; + curs3 STRING := 'baz'; + BEGIN + OPEN curs FOR SELECT * FROM xy WHERE x = 99; + INSERT INTO xy VALUES (99, 99); + OPEN curs2 FOR SELECT * FROM xy WHERE x = 99; + UPDATE xy SET y = 100 WHERE x = 99; + OPEN curs3 FOR SELECT * FROM xy WHERE x = 99; + DELETE FROM xy WHERE x = 99; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query II +FETCH FORWARD 3 FROM foo; +---- + +query II +FETCH FORWARD 3 FROM bar; +---- +99 99 + +query II +FETCH FORWARD 3 FROM baz; +---- +99 100 + +query II rowsort +SELECT * FROM xy; +---- +1 2 +3 4 + +statement ok +ABORT; + +# It is possible to use the OPEN statement in an implicit transaction, but the +# cursor is closed at the end of the transaction when the statement execution +# finishes. So, until FETCH is implemented, we can't actually read from the +# cursor. +statement ok +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +SELECT f(); + +statement error pgcode 34000 pq: cursor \"foo\" does not exist +FETCH FORWARD 5 FROM foo; + +statement error pgcode 0A000 pq: unimplemented: DECLARE SCROLL CURSOR +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs SCROLL FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; + +statement error pgcode 0A000 pq: unimplemented: bound cursor declarations are not yet supported. +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs CURSOR FOR SELECT 1; + BEGIN + OPEN curs; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; + +statement error pgcode 42P11 pq: cannot open INSERT query as cursor +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs FOR INSERT INTO xy VALUES (1, 1); + RETURN 0; + END +$$ LANGUAGE PLpgSQL; + +statement error pgcode 0A000 pq: unimplemented: CTE usage inside a function definition +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + i INT := 3; + curs STRING := 'foo'; + BEGIN + OPEN curs FOR WITH foo AS (SELECT * FROM xy WHERE x = i) SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; + +# TODO(drewk): once CTEs in routines are supported, the error should be: +# pgcode 0A000 pq: DECLARE CURSOR must not contain data-modifying statements in WITH +statement error pgcode 0A000 pq: unimplemented: CTE usage inside a function definition +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + i INT := 3; + curs STRING := 'foo'; + BEGIN + OPEN curs FOR WITH foo AS (INSERT INTO xy VALUES (1, 1) RETURNING x) SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; + +statement error pgcode 42601 pq: \"curs\" is not a known variable +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + i INT := 3; + BEGIN + OPEN curs FOR WITH foo AS (SELECT * FROM xy WHERE x = i) SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; + +statement ok +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING; + BEGIN + OPEN curs FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; + +statement error pgcode 0A000 pq: unimplemented: opening an unnamed cursor is not yet supported +SELECT f(); + +statement ok +ABORT; + +statement error pgcode 0A000 pq: unimplemented: opening a cursor in a routine with an exception block is not yet supported +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING; + BEGIN + OPEN curs FOR SELECT 1; + RETURN 0; + EXCEPTION + WHEN division_by_zero THEN + RETURN -1; + END +$$ LANGUAGE PLpgSQL; +BEGIN; + +statement ok +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT 1 // 0; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; + +statement error pgcode 22012 pq: division by zero +SELECT f(); + +# Conflict with an existing cursor. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; + +statement ok +DECLARE foo CURSOR FOR SELECT 100; + +statement error pgcode 42P03 pq: cursor \"foo\" already exists +SELECT f(); + +# Conflict between OPEN statements within the same routine. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + curs2 STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT 1; + OPEN curs2 FOR SELECT 2; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; + +statement error pgcode 42P03 pq: cursor \"foo\" already exists +SELECT f(); diff --git a/pkg/sql/logictest/tests/fakedist-disk/generated_test.go b/pkg/sql/logictest/tests/fakedist-disk/generated_test.go index 88e18d338e3c..2f4e89470795 100644 --- a/pkg/sql/logictest/tests/fakedist-disk/generated_test.go +++ b/pkg/sql/logictest/tests/fakedist-disk/generated_test.go @@ -1359,6 +1359,13 @@ func TestLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go b/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go index 926db53b5f10..992eb5c5786b 100644 --- a/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go +++ b/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go @@ -1359,6 +1359,13 @@ func TestLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/fakedist/generated_test.go b/pkg/sql/logictest/tests/fakedist/generated_test.go index f49d57d715cf..503ac7fe8b4f 100644 --- a/pkg/sql/logictest/tests/fakedist/generated_test.go +++ b/pkg/sql/logictest/tests/fakedist/generated_test.go @@ -1373,6 +1373,13 @@ func TestLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go b/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go index 69939d73d853..0baf52e63571 100644 --- a/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go +++ b/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go @@ -1345,6 +1345,13 @@ func TestLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local-mixed-22.2-23.1/generated_test.go b/pkg/sql/logictest/tests/local-mixed-22.2-23.1/generated_test.go index f642f9455c1c..39aa14d9c6af 100644 --- a/pkg/sql/logictest/tests/local-mixed-22.2-23.1/generated_test.go +++ b/pkg/sql/logictest/tests/local-mixed-22.2-23.1/generated_test.go @@ -1338,6 +1338,13 @@ func TestLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local-vec-off/generated_test.go b/pkg/sql/logictest/tests/local-vec-off/generated_test.go index 29fe95596431..b12841aafb75 100644 --- a/pkg/sql/logictest/tests/local-vec-off/generated_test.go +++ b/pkg/sql/logictest/tests/local-vec-off/generated_test.go @@ -1373,6 +1373,13 @@ func TestLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local/generated_test.go b/pkg/sql/logictest/tests/local/generated_test.go index 523563e3db47..78b0a226c816 100644 --- a/pkg/sql/logictest/tests/local/generated_test.go +++ b/pkg/sql/logictest/tests/local/generated_test.go @@ -1499,6 +1499,13 @@ func TestLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/opt/exec/execbuilder/relational.go b/pkg/sql/opt/exec/execbuilder/relational.go index b1e2007e95f5..b9ec50bd5257 100644 --- a/pkg/sql/opt/exec/execbuilder/relational.go +++ b/pkg/sql/opt/exec/execbuilder/relational.go @@ -3168,6 +3168,7 @@ func (b *Builder) buildCall(c *memo.CallExpr) (execPlan, error) { udf.TailCall, true, /* procedure */ nil, /* exceptionHandler */ + nil, /* cursorDeclaration */ ) var ep execPlan diff --git a/pkg/sql/opt/exec/execbuilder/scalar.go b/pkg/sql/opt/exec/execbuilder/scalar.go index 545d15321de1..b8a961483907 100644 --- a/pkg/sql/opt/exec/execbuilder/scalar.go +++ b/pkg/sql/opt/exec/execbuilder/scalar.go @@ -703,6 +703,7 @@ func (b *Builder) buildExistsSubquery( false, /* tailCall */ false, /* procedure */ nil, /* exceptionHandler */ + nil, /* cursorDeclaration */ ), tree.DBoolFalse, }, types.Bool), nil @@ -821,6 +822,7 @@ func (b *Builder) buildSubquery( false, /* tailCall */ false, /* procedure */ nil, /* exceptionHandler */ + nil, /* cursorDeclaration */ ), nil } @@ -878,6 +880,7 @@ func (b *Builder) buildSubquery( false, /* tailCall */ false, /* procedure */ nil, /* exceptionHandler */ + nil, /* cursorDeclaration */ ), nil } @@ -994,6 +997,7 @@ func (b *Builder) buildUDF(ctx *buildScalarCtx, scalar opt.ScalarExpr) (tree.Typ false, /* tailCall */ false, /* procedure */ nil, /* exceptionHandler */ + nil, /* cursorDeclaration */ ) } } @@ -1010,6 +1014,7 @@ func (b *Builder) buildUDF(ctx *buildScalarCtx, scalar opt.ScalarExpr) (tree.Typ udf.TailCall, false, /* procedure */ exceptionHandler, + udf.Def.CursorDeclaration, ), nil } diff --git a/pkg/sql/opt/memo/expr.go b/pkg/sql/opt/memo/expr.go index 042bea089ee6..cabde2c6c0f8 100644 --- a/pkg/sql/opt/memo/expr.go +++ b/pkg/sql/opt/memo/expr.go @@ -725,6 +725,13 @@ type UDFDefinition struct { // ExceptionBlock contains information needed for exception-handling when the // body of this routine returns an error. It can be unset. ExceptionBlock *ExceptionBlock + + // CursorDeclaration contains the information needed to open a SQL cursor with + // the result of the *first* body statement. If it is set, there will be at + // least two body statements - one to open the cursor, and one to evaluate the + // result of the routine. This invariant is enforced when the PLpgSQL routine + // is built. CursorDeclaration may be unset. + CursorDeclaration *tree.RoutineOpenCursor } // ExceptionBlock contains the information needed to match and handle errors in diff --git a/pkg/sql/opt/memo/expr_format.go b/pkg/sql/opt/memo/expr_format.go index 1cf4b6030dbd..7d38476e9076 100644 --- a/pkg/sql/opt/memo/expr_format.go +++ b/pkg/sql/opt/memo/expr_format.go @@ -957,6 +957,12 @@ func (f *ExprFmtCtx) formatScalarWithLabel( } n = tp.Child("body") for i := range udf.Def.Body { + if i == 0 && udf.Def.CursorDeclaration != nil { + // The first statement is opening a cursor. + cur := n.Child("open-cursor") + f.formatExpr(udf.Def.Body[i], cur) + continue + } f.formatExpr(udf.Def.Body[i], n) } delete(f.seenUDFs, udf.Def) diff --git a/pkg/sql/opt/memo/interner.go b/pkg/sql/opt/memo/interner.go index cb5f7f847ba7..f71d64d3d8b9 100644 --- a/pkg/sql/opt/memo/interner.go +++ b/pkg/sql/opt/memo/interner.go @@ -1231,6 +1231,33 @@ func (h *hasher) IsUDFDefinitionEqual(l, r *UDFDefinition) bool { return false } } + if l.ExceptionBlock != nil { + if r.ExceptionBlock == nil || len(l.ExceptionBlock.Actions) != len(r.ExceptionBlock.Actions) { + return false + } + for i := range l.ExceptionBlock.Actions { + if !h.IsUDFDefinitionEqual(l.ExceptionBlock.Actions[i], r.ExceptionBlock.Actions[i]) { + return false + } + if l.ExceptionBlock.Codes[i] != r.ExceptionBlock.Codes[i] { + return false + } + } + } else if r.ExceptionBlock != nil { + return false + } + if l.CursorDeclaration != nil { + if r.CursorDeclaration == nil { + return false + } + if l.CursorDeclaration.NameArgIdx != r.CursorDeclaration.NameArgIdx || + l.CursorDeclaration.Scroll != r.CursorDeclaration.Scroll || + l.CursorDeclaration.CursorSQL != r.CursorDeclaration.CursorSQL { + return false + } + } else if r.CursorDeclaration != nil { + return false + } return h.IsColListEqual(l.Params, r.Params) && l.IsRecursive == r.IsRecursive } diff --git a/pkg/sql/opt/norm/inline_funcs.go b/pkg/sql/opt/norm/inline_funcs.go index a7cd103abd91..291a47525aa8 100644 --- a/pkg/sql/opt/norm/inline_funcs.go +++ b/pkg/sql/opt/norm/inline_funcs.go @@ -413,7 +413,8 @@ func (c *CustomFuncs) InlineConstVar(f memo.FiltersExpr) memo.FiltersExpr { // 4. Its arguments are only Variable or Const expressions. // 5. It is not a record-returning function. // 6. It does not recursively call itself. -// 7. It does not have an exception-handling block. +// 7. It does not open a cursor. +// 8. It does not have an exception-handling block. // // UDFs with mutations (INSERT, UPDATE, UPSERT, DELETE) cannot be inlined, but // we do not need an explicit check for this because immutable UDFs cannot @@ -448,7 +449,7 @@ func (c *CustomFuncs) IsInlinableUDF(args memo.ScalarListExpr, udfp *memo.UDFCal } if udfp.Def.IsRecursive || udfp.Def.Volatility == volatility.Volatile || len(udfp.Def.Body) != 1 || udfp.Def.SetReturning || udfp.Def.MultiColDataSource || - udfp.Def.ExceptionBlock != nil { + udfp.Def.CursorDeclaration != nil || udfp.Def.ExceptionBlock != nil { return false } if !args.IsConstantsAndPlaceholdersAndVariables() { diff --git a/pkg/sql/opt/optbuilder/plpgsql.go b/pkg/sql/opt/optbuilder/plpgsql.go index 7a2f18ee2849..9c66ebaf0866 100644 --- a/pkg/sql/opt/optbuilder/plpgsql.go +++ b/pkg/sql/opt/optbuilder/plpgsql.go @@ -429,6 +429,7 @@ func (b *plpgsqlBuilder) buildPLpgSQLStatements(stmts []ast.Statement, s *scope) // The synchronous notice sending behavior is implemented in the // crdb_internal.plpgsql_raise builtin function. con := b.makeContinuation("_stmt_raise") + con.def.Volatility = volatility.Volatile b.appendBodyStmt(&con, b.buildPLpgSQLRaise(con.s, b.getRaiseArgs(con.s, t))) b.appendPlpgSQLStmts(&con, stmts[i+1:]) return b.callContinuation(&con, s) @@ -506,6 +507,57 @@ func (b *plpgsqlBuilder) buildPLpgSQLStatements(stmts []ast.Statement, s *scope) b.appendBodyStmt(&execCon, intoScope) return b.callContinuation(&execCon, s) + case *ast.Open: + // OPEN statements are used to create a CURSOR for the current session. + // This is handled by calling the plpgsql_open_cursor internal builtin + // function in a separate body statement that returns no results, similar + // to the RAISE implementation. + if b.exceptionBlock != nil { + panic(unimplemented.New("open with exception block", + "opening a cursor in a routine with an exception block is not yet supported", + )) + } + if t.Scroll == tree.Scroll { + panic(unimplemented.NewWithIssue(77102, "DECLARE SCROLL CURSOR")) + } + if t.Query == nil { + panic(unimplemented.New("bound cursor", "opening a bound cursor is not yet supported")) + } + if _, ok := t.Query.(*tree.Select); !ok { + panic(pgerror.Newf( + pgcode.InvalidCursorDefinition, "cannot open %s query as cursor", + t.Query.StatementTag(), + )) + } + openCon := b.makeContinuation("_stmt_open") + openCon.def.Volatility = volatility.Volatile + fmtCtx := b.ob.evalCtx.FmtCtx(tree.FmtSimple) + fmtCtx.FormatNode(t.Query) + _, source, _, err := openCon.s.FindSourceProvidingColumn(b.ob.ctx, t.CurVar) + if err != nil { + if pgerror.GetPGCode(err) == pgcode.UndefinedColumn { + panic(pgerror.Newf(pgcode.Syntax, "\"%s\" is not a known variable", t.CurVar)) + } + panic(err) + } + // Initialize the routine with the information needed to pipe the first + // body statement into a cursor. + openCon.def.CursorDeclaration = &tree.RoutineOpenCursor{ + NameArgIdx: source.(*scopeColumn).getParamOrd(), + Scroll: t.Scroll, + CursorSQL: fmtCtx.CloseAndGetString(), + } + openScope := b.ob.buildStmtAtRootWithScope(t.Query, nil /* desiredTypes */, openCon.s) + if openScope.expr.Relational().CanMutate { + // Cursors with mutations are invalid. + panic(pgerror.Newf(pgcode.FeatureNotSupported, + "DECLARE CURSOR must not contain data-modifying statements in WITH", + )) + } + b.appendBodyStmt(&openCon, openScope) + b.appendPlpgSQLStmts(&openCon, stmts[i+1:]) + return b.callContinuation(&openCon, s) + default: panic(unimplemented.New( "unimplemented PL/pgSQL statement", @@ -562,6 +614,7 @@ func (b *plpgsqlBuilder) buildPLpgSQLRaise(inScope *scope, args memo.ScalarListE ) raiseColName := scopeColName("").WithMetadataName(b.makeIdentifier("stmt_raise")) raiseScope := inScope.push() + b.ensureScopeHasExpr(raiseScope) b.ob.synthesizeColumn(raiseScope, raiseColName, types.Int, nil /* expr */, raiseCall) b.ob.constructProjectForScope(inScope, raiseScope) return raiseScope @@ -829,6 +882,7 @@ func (b *plpgsqlBuilder) buildEndOfFunctionRaise(inScope *scope) *scope { makeConstStr(pgcode.RoutineExceptionFunctionExecutedNoReturnStatement.String()), /* code */ } con := b.makeContinuation("_end_of_function") + con.def.Volatility = volatility.Volatile b.appendBodyStmt(&con, b.buildPLpgSQLRaise(con.s, args)) // Build a dummy statement that returns NULL. It won't be executed, but // ensures that the continuation routine's return type is correct. @@ -919,11 +973,7 @@ func (b *plpgsqlBuilder) callContinuation(con *continuation, s *scope) *scope { if err != nil { panic(err) } - if source != nil { - args = append(args, b.ob.factory.ConstructVariable(source.(*scopeColumn).id)) - } else { - args = append(args, b.ob.factory.ConstructNull(typ)) - } + args = append(args, b.ob.factory.ConstructVariable(source.(*scopeColumn).id)) } for _, dec := range b.decls { addArg(dec.Var, b.varTypes[dec.Var]) diff --git a/pkg/sql/opt/optbuilder/scope_column.go b/pkg/sql/opt/optbuilder/scope_column.go index 4ca6686953a3..1074acd3c1e3 100644 --- a/pkg/sql/opt/optbuilder/scope_column.go +++ b/pkg/sql/opt/optbuilder/scope_column.go @@ -135,6 +135,15 @@ func (c *scopeColumn) setParamOrd(ord int) { c.paramOrd = funcParamOrd(ord + 1) } +// getParamOrd retrieves the 0-based ordinal from the column's 1-based function +// parameter ordinal. Panics if the function ordinal is unset. +func (c *scopeColumn) getParamOrd() int { + if c.paramOrd < 1 { + panic(errors.AssertionFailedf("expected non-negative argument ordinal")) + } + return int(c.paramOrd) - 1 +} + // funcParamReferencedBy returns true if the scopeColumn is a function parameter // column that can be referenced by the given placeholder. func (c *scopeColumn) funcParamReferencedBy(idx tree.PlaceholderIdx) bool { diff --git a/pkg/sql/opt/optbuilder/testdata/udf_plpgsql b/pkg/sql/opt/optbuilder/testdata/udf_plpgsql index eacc4350caeb..4c21e18981c3 100644 --- a/pkg/sql/opt/optbuilder/testdata/udf_plpgsql +++ b/pkg/sql/opt/optbuilder/testdata/udf_plpgsql @@ -4603,3 +4603,222 @@ project │ └── projections │ └── const: 0 [as=stmt_return_9:20] └── const: 1 + +# Testing OPEN statement. +exec-ddl +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +---- + +build format=show-scalars +SELECT f(); +---- +project + ├── columns: f:6 + ├── values + │ └── tuple + └── projections + └── udf: f [as=f:6] + └── body + └── limit + ├── columns: "_stmt_open_1":5 + ├── project + │ ├── columns: "_stmt_open_1":5 + │ ├── project + │ │ ├── columns: curs:1!null + │ │ ├── values + │ │ │ └── tuple + │ │ └── projections + │ │ └── const: 'foo' [as=curs:1] + │ └── projections + │ └── udf: _stmt_open_1 [as="_stmt_open_1":5] + │ ├── args + │ │ └── variable: curs:1 + │ ├── params: curs:2 + │ └── body + │ ├── open-cursor + │ │ └── project + │ │ ├── columns: "?column?":3!null + │ │ ├── values + │ │ │ └── tuple + │ │ └── projections + │ │ └── const: 1 [as="?column?":3] + │ └── project + │ ├── columns: stmt_return_2:4!null + │ ├── values + │ │ └── tuple + │ └── projections + │ └── const: 0 [as=stmt_return_2:4] + └── const: 1 + +exec-ddl +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + i INT := 3; + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT * FROM xy WHERE x = i; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +---- + +build format=show-scalars +SELECT f(); +---- +project + ├── columns: f:12 + ├── values + │ └── tuple + └── projections + └── udf: f [as=f:12] + └── body + └── limit + ├── columns: "_stmt_open_1":11 + ├── project + │ ├── columns: "_stmt_open_1":11 + │ ├── project + │ │ ├── columns: curs:2!null i:1!null + │ │ ├── project + │ │ │ ├── columns: i:1!null + │ │ │ ├── values + │ │ │ │ └── tuple + │ │ │ └── projections + │ │ │ └── const: 3 [as=i:1] + │ │ └── projections + │ │ └── const: 'foo' [as=curs:2] + │ └── projections + │ └── udf: _stmt_open_1 [as="_stmt_open_1":11] + │ ├── args + │ │ ├── variable: i:1 + │ │ └── variable: curs:2 + │ ├── params: i:3 curs:4 + │ └── body + │ ├── open-cursor + │ │ └── project + │ │ ├── columns: x:5!null y:6 + │ │ └── select + │ │ ├── columns: x:5!null y:6 rowid:7!null crdb_internal_mvcc_timestamp:8 tableoid:9 + │ │ ├── scan xy + │ │ │ └── columns: x:5 y:6 rowid:7!null crdb_internal_mvcc_timestamp:8 tableoid:9 + │ │ └── filters + │ │ └── eq + │ │ ├── variable: x:5 + │ │ └── variable: i:3 + │ └── project + │ ├── columns: stmt_return_2:10!null + │ ├── values + │ │ └── tuple + │ └── projections + │ └── const: 0 [as=stmt_return_2:10] + └── const: 1 + +exec-ddl +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + curs2 STRING := 'bar'; + curs3 STRING := 'baz'; + BEGIN + OPEN curs FOR SELECT 1; + OPEN curs2 FOR SELECT 2; + OPEN curs3 FOR SELECT 3; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +---- + +build format=show-scalars +SELECT f(); +---- +project + ├── columns: f:20 + ├── values + │ └── tuple + └── projections + └── udf: f [as=f:20] + └── body + └── limit + ├── columns: "_stmt_open_1":19 + ├── project + │ ├── columns: "_stmt_open_1":19 + │ ├── project + │ │ ├── columns: curs3:3!null curs:1!null curs2:2!null + │ │ ├── project + │ │ │ ├── columns: curs2:2!null curs:1!null + │ │ │ ├── project + │ │ │ │ ├── columns: curs:1!null + │ │ │ │ ├── values + │ │ │ │ │ └── tuple + │ │ │ │ └── projections + │ │ │ │ └── const: 'foo' [as=curs:1] + │ │ │ └── projections + │ │ │ └── const: 'bar' [as=curs2:2] + │ │ └── projections + │ │ └── const: 'baz' [as=curs3:3] + │ └── projections + │ └── udf: _stmt_open_1 [as="_stmt_open_1":19] + │ ├── args + │ │ ├── variable: curs:1 + │ │ ├── variable: curs2:2 + │ │ └── variable: curs3:3 + │ ├── params: curs:4 curs2:5 curs3:6 + │ └── body + │ ├── open-cursor + │ │ └── project + │ │ ├── columns: "?column?":7!null + │ │ ├── values + │ │ │ └── tuple + │ │ └── projections + │ │ └── const: 1 [as="?column?":7] + │ └── project + │ ├── columns: "_stmt_open_2":18 + │ ├── values + │ │ └── tuple + │ └── projections + │ └── udf: _stmt_open_2 [as="_stmt_open_2":18] + │ ├── args + │ │ ├── variable: curs:4 + │ │ ├── variable: curs2:5 + │ │ └── variable: curs3:6 + │ ├── params: curs:8 curs2:9 curs3:10 + │ └── body + │ ├── open-cursor + │ │ └── project + │ │ ├── columns: "?column?":11!null + │ │ ├── values + │ │ │ └── tuple + │ │ └── projections + │ │ └── const: 2 [as="?column?":11] + │ └── project + │ ├── columns: "_stmt_open_3":17 + │ ├── values + │ │ └── tuple + │ └── projections + │ └── udf: _stmt_open_3 [as="_stmt_open_3":17] + │ ├── args + │ │ ├── variable: curs:8 + │ │ ├── variable: curs2:9 + │ │ └── variable: curs3:10 + │ ├── params: curs:12 curs2:13 curs3:14 + │ └── body + │ ├── open-cursor + │ │ └── project + │ │ ├── columns: "?column?":15!null + │ │ ├── values + │ │ │ └── tuple + │ │ └── projections + │ │ └── const: 3 [as="?column?":15] + │ └── project + │ ├── columns: stmt_return_4:16!null + │ ├── values + │ │ └── tuple + │ └── projections + │ └── const: 0 [as=stmt_return_4:16] + └── const: 1 diff --git a/pkg/sql/routine.go b/pkg/sql/routine.go index fc5a7c2f5e49..e8602767195f 100644 --- a/pkg/sql/routine.go +++ b/pkg/sql/routine.go @@ -15,6 +15,8 @@ import ( "strconv" "github.com/cockroachdb/cockroach/pkg/kv" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/colinfo" + "github.com/cockroachdb/cockroach/pkg/sql/isql" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/plpgsql" @@ -22,6 +24,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/util" + "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/errors" ) @@ -226,23 +230,30 @@ func (g *routineGenerator) startInternal(ctx context.Context, txn *kv.Txn) (err stmtIdx := 0 ef := newExecFactory(ctx, g.p) rrw := NewRowResultWriter(&g.rch) + var cursorHelper *plpgsqlCursorHelper err = g.expr.ForEachPlan(ctx, ef, g.args, func(plan tree.RoutinePlan, isFinalPlan bool) error { stmtIdx++ opName := "udf-stmt-" + g.expr.Name + "-" + strconv.Itoa(stmtIdx) ctx, sp := tracing.ChildSpan(ctx, opName) defer sp.Finish() - // If this is the last statement and it is not a procedure, use the - // rowResultWriter created above. Otherwise, use a rowResultWriter that - // drops all rows added to it. - // - // We can use a droppingResultWriter for all statements in a procedure - // because we do not yet allow OUT or INOUT parameters, so a procedure - // never returns values. var w rowResultWriter + openCursor := stmtIdx == 1 && g.expr.CursorDeclaration != nil if isFinalPlan && !g.expr.Procedure { + // The result of this statement is the routine's output. This is never the + // case for a procedure, which does not output any rows (since we do not + // yet support OUT or INOUT parameters). w = rrw + } else if openCursor { + // The result of the first statement will be used to open a SQL cursor. + cursorHelper, err = g.newCursorHelper(ctx, plan.(*planComponents)) + if err != nil { + return err + } + w = NewRowResultWriter(&cursorHelper.container) } else { + // The result of this statement is not needed. Use a rowResultWriter that + // drops all rows added to it. w = &droppingResultWriter{} } @@ -259,9 +270,15 @@ func (g *routineGenerator) startInternal(ctx context.Context, txn *kv.Txn) (err if err != nil { return err } + if openCursor { + return cursorHelper.createCursor(g.p) + } return nil }) if err != nil { + if cursorHelper != nil { + err = errors.CombineErrors(err, cursorHelper.Close()) + } return g.handleException(ctx, err) } @@ -359,3 +376,109 @@ func (d *droppingResultWriter) SetError(err error) { func (d *droppingResultWriter) Err() error { return d.err } + +func (g *routineGenerator) newCursorHelper( + ctx context.Context, plan *planComponents, +) (*plpgsqlCursorHelper, error) { + open := g.expr.CursorDeclaration + if open.NameArgIdx < 0 || open.NameArgIdx >= len(g.args) { + panic(errors.AssertionFailedf("unexpected name argument index: %d", open.NameArgIdx)) + } + if g.args[open.NameArgIdx] == tree.DNull { + return nil, unimplemented.New("unnamed cursor", + "opening an unnamed cursor is not yet supported", + ) + } + planCols := plan.main.planColumns() + cursorHelper := &plpgsqlCursorHelper{ + ctx: ctx, + cursorName: tree.Name(tree.MustBeDString(g.args[open.NameArgIdx])), + resultCols: make(colinfo.ResultColumns, len(planCols)), + } + copy(cursorHelper.resultCols, planCols) + cursorHelper.container.Init( + ctx, + getTypesFromResultColumns(planCols), + g.p.ExtendedEvalContextCopy(), + "routine_open_cursor", /* opName */ + ) + return cursorHelper, nil +} + +// plpgsqlCursorHelper wraps a row container in order to feed the results of +// executing a SQL statement to a SQL cursor. Note that the SQL statement is not +// lazily executed; its entire result is written to the container. +// TODO(drewk): while the row container can spill to disk, we should default to +// lazy execution for cursors for performance reasons. +type plpgsqlCursorHelper struct { + ctx context.Context + cursorName tree.Name + cursorSql string + + // Fields related to implementing the isql.Rows interface. + container rowContainerHelper + iter *rowContainerIterator + resultCols colinfo.ResultColumns + lastRow tree.Datums + lastErr error + rowsAffected int +} + +func (h *plpgsqlCursorHelper) createCursor(p *planner) error { + h.iter = newRowContainerIterator(h.ctx, h.container) + cursor := &sqlCursor{ + Rows: h, + readSeqNum: p.txn.GetReadSeqNum(), + txn: p.txn, + statement: h.cursorSql, + created: timeutil.Now(), + } + if err := p.checkIfCursorExists(h.cursorName); err != nil { + return err + } + return p.sqlCursors.addCursor(h.cursorName, cursor) +} + +var _ isql.Rows = &plpgsqlCursorHelper{} + +// Next implements the isql.Rows interface. +func (h *plpgsqlCursorHelper) Next(_ context.Context) (bool, error) { + h.lastRow, h.lastErr = h.iter.Next() + if h.lastErr != nil { + return false, h.lastErr + } + if h.lastRow != nil { + h.rowsAffected++ + } + return h.lastRow != nil, nil +} + +// Cur implements the isql.Rows interface. +func (h *plpgsqlCursorHelper) Cur() tree.Datums { + return h.lastRow +} + +// RowsAffected implements the isql.Rows interface. +func (h *plpgsqlCursorHelper) RowsAffected() int { + return h.rowsAffected +} + +// Close implements the isql.Rows interface. +func (h *plpgsqlCursorHelper) Close() error { + if h.iter != nil { + h.iter.Close() + h.iter = nil + } + h.container.Close(h.ctx) + return h.lastErr +} + +// Types implements the isql.Rows interface. +func (h *plpgsqlCursorHelper) Types() colinfo.ResultColumns { + return h.resultCols +} + +// HasResults implements the isql.Rows interface. +func (h *plpgsqlCursorHelper) HasResults() bool { + return h.lastRow != nil +} diff --git a/pkg/sql/sem/tree/routine.go b/pkg/sql/sem/tree/routine.go index 83d1a1c6ebb3..2e842cabdff7 100644 --- a/pkg/sql/sem/tree/routine.go +++ b/pkg/sql/sem/tree/routine.go @@ -126,6 +126,10 @@ type RoutineExpr struct { // ExceptionHandler holds the information needed to handle errors if an // exception block was defined. ExceptionHandler *RoutineExceptionHandler + + // CursorDeclaration contains the information needed to open a SQL cursor with + // the result of the *first* body statement. It may be unset. + CursorDeclaration *RoutineOpenCursor } // NewTypedRoutineExpr returns a new RoutineExpr that is well-typed. @@ -141,6 +145,7 @@ func NewTypedRoutineExpr( tailCall bool, procedure bool, exceptionHandler *RoutineExceptionHandler, + cursorDeclaration *RoutineOpenCursor, ) *RoutineExpr { return &RoutineExpr{ Args: args, @@ -154,6 +159,7 @@ func NewTypedRoutineExpr( TailCall: tailCall, Procedure: procedure, ExceptionHandler: exceptionHandler, + CursorDeclaration: cursorDeclaration, } } @@ -191,3 +197,19 @@ type RoutineExceptionHandler struct { // Actions contains a routine to handle each error code. Actions []*RoutineExpr } + +// RoutineOpenCursor stores the information needed to correctly open a cursor +// with the output of a routine. +type RoutineOpenCursor struct { + // NameArgIdx is the index of the routine argument that contains the name of + // the cursor that will be created. + NameArgIdx int + + // Scroll is the scroll option for the cursor, if one was specified. The other + // cursor options are not valid in PLpgSQL. + Scroll CursorScrollOption + + // CursorSQL is a formatted string used to associate the original SQL + // statement with the cursor. + CursorSQL string +} diff --git a/pkg/sql/sql_cursor.go b/pkg/sql/sql_cursor.go index 1a722d6ed03f..5c5166c86aae 100644 --- a/pkg/sql/sql_cursor.go +++ b/pkg/sql/sql_cursor.go @@ -59,12 +59,8 @@ func (p *planner) DeclareCursor(ctx context.Context, s *tree.DeclareCursor) (pla sd.StmtTimeout = 0 } ie := p.ExecCfg().InternalDB.NewInternalExecutor(sd) - if cursor := p.sqlCursors.getCursor(s.Name); cursor != nil { - return nil, pgerror.Newf(pgcode.DuplicateCursor, "cursor %q already exists", s.Name) - } - - if p.extendedEvalCtx.PreparedStatementState.HasPortal(string(s.Name)) { - return nil, pgerror.Newf(pgcode.DuplicateCursor, "cursor %q already exists as portal", s.Name) + if err := p.checkIfCursorExists(s.Name); err != nil { + return nil, err } // Try to plan the cursor query to make sure that it's valid. @@ -122,6 +118,18 @@ func (p *planner) DeclareCursor(ctx context.Context, s *tree.DeclareCursor) (pla }, nil } +// checkIfCursorExists checks whether a cursor or portal with the given name +// already exists, and returns an error if one does. +func (p *planner) checkIfCursorExists(name tree.Name) error { + if cursor := p.sqlCursors.getCursor(name); cursor != nil { + return pgerror.Newf(pgcode.DuplicateCursor, "cursor %q already exists", name) + } + if p.extendedEvalCtx.PreparedStatementState.HasPortal(string(name)) { + return pgerror.Newf(pgcode.DuplicateCursor, "cursor %q already exists as portal", name) + } + return nil +} + var errBackwardScan = pgerror.Newf(pgcode.ObjectNotInPrerequisiteState, "cursor can only scan forward") // FetchCursor implements the FETCH and MOVE statements. From 6fa48708b42d41cf592edbdfe77ccc1cb08f60d5 Mon Sep 17 00:00:00 2001 From: Arul Ajmani Date: Wed, 27 Sep 2023 17:39:26 -0500 Subject: [PATCH 3/3] kvserver: latching changes for replicated shared locks Two locking requests from the same transaction that are trying to acquire replicated shared locks need to be isolated from one another. They don't need to be isolated against shared locking requests from other transactions and unreplicated shared lock attempts from the same transaction. To achieve these semantics, we introduce a per-transaction range local key that all replicated shared locking requests declare non-MVCC write latches over. Closes #109668 Release note: None --- pkg/keys/constants.go | 5 ++ pkg/keys/doc.go | 13 ++-- pkg/keys/keys.go | 19 +++++ pkg/keys/printer.go | 20 +++++ pkg/keys/printer_test.go | 1 + pkg/kv/kvserver/batcheval/declare.go | 14 +++- .../concurrency/datadriven_util_test.go | 13 ++++ .../concurrency_manager/shared_locks_latches | 75 +++++++++++++++++++ 8 files changed, 152 insertions(+), 8 deletions(-) diff --git a/pkg/keys/constants.go b/pkg/keys/constants.go index 25f68e360ecc..c8846a3e6d60 100644 --- a/pkg/keys/constants.go +++ b/pkg/keys/constants.go @@ -72,6 +72,11 @@ var ( // AbortSpan protects a transaction from re-reading its own intents // after it's been aborted. LocalAbortSpanSuffix = []byte("abc-") + // LocalReplicatedSharedLocksTransactionLatchingKeySuffix specifies the key + // suffix ("rsl" = replicated shared locks) for all replicated shared lock + // attempts, per transaction. The detail about the transaction is the + // transaction id. + LocalReplicatedSharedLocksTransactionLatchingKeySuffix = roachpb.RKey("rsl-") // localRangeFrozenStatusSuffix is DEPRECATED and remains to prevent reuse. localRangeFrozenStatusSuffix = []byte("fzn-") // LocalRangeGCThresholdSuffix is the suffix for the GC threshold. It keeps diff --git a/pkg/keys/doc.go b/pkg/keys/doc.go index fe42458b864e..8138eb7bf071 100644 --- a/pkg/keys/doc.go +++ b/pkg/keys/doc.go @@ -181,12 +181,13 @@ var _ = [...]interface{}{ // range as a whole. Though they are replicated, they are unaddressable. // Typical examples are MVCC stats and the abort span. They all share // `LocalRangeIDPrefix` and `LocalRangeIDReplicatedInfix`. - AbortSpanKey, // "abc-" - RangeGCThresholdKey, // "lgc-" - RangeAppliedStateKey, // "rask" - RangeLeaseKey, // "rll-" - RangePriorReadSummaryKey, // "rprs" - RangeVersionKey, // "rver" + AbortSpanKey, // "abc-" + ReplicatedSharedLocksTransactionLatchingKey, // "rsl-" + RangeGCThresholdKey, // "lgc-" + RangeAppliedStateKey, // "rask" + RangeLeaseKey, // "rll-" + RangePriorReadSummaryKey, // "rprs" + RangeVersionKey, // "rver" // 2. Unreplicated range-ID local keys: These contain metadata that // pertain to just one replica of a range. They are unreplicated and diff --git a/pkg/keys/keys.go b/pkg/keys/keys.go index 471bdb95755d..22db176dbb62 100644 --- a/pkg/keys/keys.go +++ b/pkg/keys/keys.go @@ -251,6 +251,16 @@ func AbortSpanKey(rangeID roachpb.RangeID, txnID uuid.UUID) roachpb.Key { return MakeRangeIDPrefixBuf(rangeID).AbortSpanKey(txnID) } +// ReplicatedSharedLocksTransactionLatchingKey returns a range-local key, based +// on the provided range ID and transaction ID, that all replicated shared +// locking requests from the specified transaction should use to serialize on +// latches. +func ReplicatedSharedLocksTransactionLatchingKey( + rangeID roachpb.RangeID, txnID uuid.UUID, +) roachpb.Key { + return MakeRangeIDPrefixBuf(rangeID).ReplicatedSharedLocksTransactionLatchingKey(txnID) +} + // DecodeAbortSpanKey decodes the provided AbortSpan entry, // returning the transaction ID. func DecodeAbortSpanKey(key roachpb.Key, dest []byte) (uuid.UUID, error) { @@ -1066,6 +1076,15 @@ func (b RangeIDPrefixBuf) AbortSpanKey(txnID uuid.UUID) roachpb.Key { return encoding.EncodeBytesAscending(key, txnID.GetBytes()) } +// ReplicatedSharedLocksTransactionLatchingKey returns a range-local key, by +// range ID, for a key on which all replicated shared locking requests from a +// specific transaction should serialize on latches. The per-transaction bit is +// achieved by encoding the supplied transaction ID into the key. +func (b RangeIDPrefixBuf) ReplicatedSharedLocksTransactionLatchingKey(txnID uuid.UUID) roachpb.Key { + key := append(b.replicatedPrefix(), LocalReplicatedSharedLocksTransactionLatchingKeySuffix...) + return encoding.EncodeBytesAscending(key, txnID.GetBytes()) +} + // RangeAppliedStateKey returns a system-local key for the range applied state key. // See comment on RangeAppliedStateKey function. func (b RangeIDPrefixBuf) RangeAppliedStateKey() roachpb.Key { diff --git a/pkg/keys/printer.go b/pkg/keys/printer.go index 52a19a1863b7..b4e016b256e7 100644 --- a/pkg/keys/printer.go +++ b/pkg/keys/printer.go @@ -107,6 +107,10 @@ var ( psFunc func(rangeID roachpb.RangeID, input string) (string, roachpb.Key) }{ {name: "AbortSpan", suffix: LocalAbortSpanSuffix, ppFunc: abortSpanKeyPrint, psFunc: abortSpanKeyParse}, + {name: "ReplicatedSharedLocksTransactionLatch", + suffix: LocalReplicatedSharedLocksTransactionLatchingKeySuffix, + ppFunc: replicatedSharedLocksTransactionLatchingKeyPrint, + }, {name: "RangeTombstone", suffix: LocalRangeTombstoneSuffix}, {name: "RaftHardState", suffix: LocalRaftHardStateSuffix}, {name: "RangeAppliedState", suffix: LocalRangeAppliedStateSuffix}, @@ -567,6 +571,22 @@ func abortSpanKeyPrint(buf *redact.StringBuilder, key roachpb.Key) { buf.Printf("/%q", txnID) } +func replicatedSharedLocksTransactionLatchingKeyPrint(buf *redact.StringBuilder, key roachpb.Key) { + _, id, err := encoding.DecodeBytesAscending([]byte(key), nil) + if err != nil { + buf.Printf("/%q/err:%v", key, err) + return + } + + txnID, err := uuid.FromBytes(id) + if err != nil { + buf.Printf("/%q/err:%v", key, err) + return + } + + buf.Printf("/%q", txnID) +} + func print(buf *redact.StringBuilder, _ []encoding.Direction, key roachpb.Key) { buf.Printf("/%q", []byte(key)) } diff --git a/pkg/keys/printer_test.go b/pkg/keys/printer_test.go index d21dbd89a907..1a51519d6365 100644 --- a/pkg/keys/printer_test.go +++ b/pkg/keys/printer_test.go @@ -242,6 +242,7 @@ func TestPrettyPrint(t *testing.T) { {keys.StoreLossOfQuorumRecoveryCleanupActionsKey(), "/Local/Store/lossOfQuorumRecovery/cleanup", revertSupportUnknown}, {keys.AbortSpanKey(roachpb.RangeID(1000001), txnID), fmt.Sprintf(`/Local/RangeID/1000001/r/AbortSpan/%q`, txnID), revertSupportUnknown}, + {keys.ReplicatedSharedLocksTransactionLatchingKey(roachpb.RangeID(1000001), txnID), fmt.Sprintf(`/Local/RangeID/1000001/r/ReplicatedSharedLocksTransactionLatch/%q`, txnID), revertSupportUnknown}, {keys.RangeAppliedStateKey(roachpb.RangeID(1000001)), "/Local/RangeID/1000001/r/RangeAppliedState", revertSupportUnknown}, {keys.RaftTruncatedStateKey(roachpb.RangeID(1000001)), "/Local/RangeID/1000001/u/RaftTruncatedState", revertSupportUnknown}, {keys.RangeLeaseKey(roachpb.RangeID(1000001)), "/Local/RangeID/1000001/r/RangeLease", revertSupportUnknown}, diff --git a/pkg/kv/kvserver/batcheval/declare.go b/pkg/kv/kvserver/batcheval/declare.go index 74951f4d2016..35a5b3a5c292 100644 --- a/pkg/kv/kvserver/batcheval/declare.go +++ b/pkg/kv/kvserver/batcheval/declare.go @@ -50,7 +50,7 @@ func DefaultDeclareKeys( // ensures that the commands are fully isolated from conflicting transactions // when it evaluated. func DefaultDeclareIsolatedKeys( - _ ImmutableRangeState, + rs ImmutableRangeState, header *kvpb.Header, req kvpb.Request, latchSpans *spanset.SpanSet, @@ -92,7 +92,8 @@ func DefaultDeclareIsolatedKeys( // Get the correct lock strength to use for {lock,latch} spans if we're // dealing with locking read requests. if readOnlyReq, ok := req.(kvpb.LockingReadRequest); ok { - str, _ = readOnlyReq.KeyLocking() + var dur lock.Durability + str, dur = readOnlyReq.KeyLocking() switch str { case lock.None: panic(errors.AssertionFailedf("unexpected non-locking read handling")) @@ -109,6 +110,15 @@ func DefaultDeclareIsolatedKeys( // from concurrent writers operating at lower timestamps, a shared-locking // read extends this protection to all timestamps. timestamp = hlc.MaxTimestamp + if dur == lock.Replicated && header.Txn != nil { + // Concurrent replicated shared lock attempts by the same transaction + // need to be isolated from one another. We acquire a write latch on + // a per-transaction local key to achieve this. See + // https://github.com/cockroachdb/cockroach/issues/109668. + latchSpans.AddNonMVCC(spanset.SpanReadWrite, roachpb.Span{ + Key: keys.ReplicatedSharedLocksTransactionLatchingKey(rs.GetRangeID(), header.Txn.ID), + }) + } case lock.Exclusive: // Reads that acquire exclusive locks acquire write latches at the // request's timestamp. This isolates them from all concurrent writes, diff --git a/pkg/kv/kvserver/concurrency/datadriven_util_test.go b/pkg/kv/kvserver/concurrency/datadriven_util_test.go index e2e9ebac959f..b31e38b45e45 100644 --- a/pkg/kv/kvserver/concurrency/datadriven_util_test.go +++ b/pkg/kv/kvserver/concurrency/datadriven_util_test.go @@ -86,6 +86,10 @@ func scanUserPriority(t *testing.T, d *datadriven.TestData) roachpb.UserPriority func scanLockDurability(t *testing.T, d *datadriven.TestData) lock.Durability { var durS string d.ScanArgs(t, "dur", &durS) + return getLockDurability(t, d, durS) +} + +func getLockDurability(t *testing.T, d *datadriven.TestData, durS string) lock.Durability { switch durS { case "r": return lock.Replicated @@ -177,6 +181,13 @@ func scanSingleRequest( } return concurrency.GetStrength(t, d, s) } + maybeGetDur := func() lock.Durability { + s, ok := fields["dur"] + if !ok { + return lock.Unreplicated + } + return getLockDurability(t, d, s) + } switch cmd { case "get": @@ -184,6 +195,7 @@ func scanSingleRequest( r.Sequence = maybeGetSeq() r.Key = roachpb.Key(mustGetField("key")) r.KeyLockingStrength = maybeGetStr() + r.KeyLockingDurability = maybeGetDur() return &r case "scan": @@ -194,6 +206,7 @@ func scanSingleRequest( r.EndKey = roachpb.Key(v) } r.KeyLockingStrength = maybeGetStr() + r.KeyLockingDurability = maybeGetDur() return &r case "put": diff --git a/pkg/kv/kvserver/concurrency/testdata/concurrency_manager/shared_locks_latches b/pkg/kv/kvserver/concurrency/testdata/concurrency_manager/shared_locks_latches index 42b811010096..9781a62be7e0 100644 --- a/pkg/kv/kvserver/concurrency/testdata/concurrency_manager/shared_locks_latches +++ b/pkg/kv/kvserver/concurrency/testdata/concurrency_manager/shared_locks_latches @@ -641,3 +641,78 @@ finish req=req33 finish req=req34 ---- [-] finish req34: finishing request + +# ------------------------------------------------------------------------------ +# Ensure concurrent replicated shared locking requests by the same transaction +# conflict on latches. Also ensure concurrent replicated shared lock attempts +# by different transactions do not. +# ------------------------------------------------------------------------------ + +new-request name=req35 txn=txn2 ts=11,1 + get key=c str=shared dur=r +---- + +sequence req=req35 +---- +[35] sequence req35: sequencing request +[35] sequence req35: acquiring latches +[35] sequence req35: scanning lock table for conflicting locks +[35] sequence req35: sequencing complete, returned guard + +new-request name=req36 txn=txn2 ts=11,1 + scan key=a endkey=f str=shared dur=r +---- + +sequence req=req36 +---- +[36] sequence req36: sequencing request +[36] sequence req36: acquiring latches +[36] sequence req36: waiting to acquire write latch ‹/Local/RangeID/1/r/ReplicatedSharedLocksTransactionLatch/"00000002-0000-0000-0000-000000000000"›@0,0, held by write latch ‹/Local/RangeID/1/r/ReplicatedSharedLocksTransactionLatch/"00000002-0000-0000-0000-000000000000"›@0,0 +[36] sequence req36: blocked on select in spanlatch.(*Manager).waitForSignal + +new-request name=req37 txn=txn1 ts=11,1 + get key=c str=shared dur=r +---- + +sequence req=req37 +---- +[37] sequence req37: sequencing request +[37] sequence req37: acquiring latches +[37] sequence req37: scanning lock table for conflicting locks +[37] sequence req37: sequencing complete, returned guard + + +# Unreplicated shared locking request from txn2. Shouldn't conflict on latches. +new-request name=req38 txn=txn2 ts=11,1 + get key=c str=shared dur=u +---- + +sequence req=req38 +---- +[38] sequence req38: sequencing request +[38] sequence req38: acquiring latches +[38] sequence req38: scanning lock table for conflicting locks +[38] sequence req38: sequencing complete, returned guard + +debug-latch-manager +---- +write count: 3 + read count: 4 + +finish req=req35 +---- +[-] finish req35: finishing request +[36] sequence req36: scanning lock table for conflicting locks +[36] sequence req36: sequencing complete, returned guard + +finish req=req36 +---- +[-] finish req36: finishing request + +finish req=req37 +---- +[-] finish req37: finishing request + +finish req=req38 +---- +[-] finish req38: finishing request