diff --git a/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go b/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go index 027c6485e9d4..af92e4aab661 100644 --- a/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go +++ b/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go @@ -1381,6 +1381,13 @@ func TestTenantLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestTenantLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestTenantLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/keys/constants.go b/pkg/keys/constants.go index 25f68e360ecc..c8846a3e6d60 100644 --- a/pkg/keys/constants.go +++ b/pkg/keys/constants.go @@ -72,6 +72,11 @@ var ( // AbortSpan protects a transaction from re-reading its own intents // after it's been aborted. LocalAbortSpanSuffix = []byte("abc-") + // LocalReplicatedSharedLocksTransactionLatchingKeySuffix specifies the key + // suffix ("rsl" = replicated shared locks) for all replicated shared lock + // attempts, per transaction. The detail about the transaction is the + // transaction id. + LocalReplicatedSharedLocksTransactionLatchingKeySuffix = roachpb.RKey("rsl-") // localRangeFrozenStatusSuffix is DEPRECATED and remains to prevent reuse. localRangeFrozenStatusSuffix = []byte("fzn-") // LocalRangeGCThresholdSuffix is the suffix for the GC threshold. It keeps diff --git a/pkg/keys/doc.go b/pkg/keys/doc.go index fe42458b864e..8138eb7bf071 100644 --- a/pkg/keys/doc.go +++ b/pkg/keys/doc.go @@ -181,12 +181,13 @@ var _ = [...]interface{}{ // range as a whole. Though they are replicated, they are unaddressable. // Typical examples are MVCC stats and the abort span. They all share // `LocalRangeIDPrefix` and `LocalRangeIDReplicatedInfix`. - AbortSpanKey, // "abc-" - RangeGCThresholdKey, // "lgc-" - RangeAppliedStateKey, // "rask" - RangeLeaseKey, // "rll-" - RangePriorReadSummaryKey, // "rprs" - RangeVersionKey, // "rver" + AbortSpanKey, // "abc-" + ReplicatedSharedLocksTransactionLatchingKey, // "rsl-" + RangeGCThresholdKey, // "lgc-" + RangeAppliedStateKey, // "rask" + RangeLeaseKey, // "rll-" + RangePriorReadSummaryKey, // "rprs" + RangeVersionKey, // "rver" // 2. Unreplicated range-ID local keys: These contain metadata that // pertain to just one replica of a range. They are unreplicated and diff --git a/pkg/keys/keys.go b/pkg/keys/keys.go index 471bdb95755d..22db176dbb62 100644 --- a/pkg/keys/keys.go +++ b/pkg/keys/keys.go @@ -251,6 +251,16 @@ func AbortSpanKey(rangeID roachpb.RangeID, txnID uuid.UUID) roachpb.Key { return MakeRangeIDPrefixBuf(rangeID).AbortSpanKey(txnID) } +// ReplicatedSharedLocksTransactionLatchingKey returns a range-local key, based +// on the provided range ID and transaction ID, that all replicated shared +// locking requests from the specified transaction should use to serialize on +// latches. +func ReplicatedSharedLocksTransactionLatchingKey( + rangeID roachpb.RangeID, txnID uuid.UUID, +) roachpb.Key { + return MakeRangeIDPrefixBuf(rangeID).ReplicatedSharedLocksTransactionLatchingKey(txnID) +} + // DecodeAbortSpanKey decodes the provided AbortSpan entry, // returning the transaction ID. func DecodeAbortSpanKey(key roachpb.Key, dest []byte) (uuid.UUID, error) { @@ -1066,6 +1076,15 @@ func (b RangeIDPrefixBuf) AbortSpanKey(txnID uuid.UUID) roachpb.Key { return encoding.EncodeBytesAscending(key, txnID.GetBytes()) } +// ReplicatedSharedLocksTransactionLatchingKey returns a range-local key, by +// range ID, for a key on which all replicated shared locking requests from a +// specific transaction should serialize on latches. The per-transaction bit is +// achieved by encoding the supplied transaction ID into the key. +func (b RangeIDPrefixBuf) ReplicatedSharedLocksTransactionLatchingKey(txnID uuid.UUID) roachpb.Key { + key := append(b.replicatedPrefix(), LocalReplicatedSharedLocksTransactionLatchingKeySuffix...) + return encoding.EncodeBytesAscending(key, txnID.GetBytes()) +} + // RangeAppliedStateKey returns a system-local key for the range applied state key. // See comment on RangeAppliedStateKey function. func (b RangeIDPrefixBuf) RangeAppliedStateKey() roachpb.Key { diff --git a/pkg/keys/printer.go b/pkg/keys/printer.go index 52a19a1863b7..b4e016b256e7 100644 --- a/pkg/keys/printer.go +++ b/pkg/keys/printer.go @@ -107,6 +107,10 @@ var ( psFunc func(rangeID roachpb.RangeID, input string) (string, roachpb.Key) }{ {name: "AbortSpan", suffix: LocalAbortSpanSuffix, ppFunc: abortSpanKeyPrint, psFunc: abortSpanKeyParse}, + {name: "ReplicatedSharedLocksTransactionLatch", + suffix: LocalReplicatedSharedLocksTransactionLatchingKeySuffix, + ppFunc: replicatedSharedLocksTransactionLatchingKeyPrint, + }, {name: "RangeTombstone", suffix: LocalRangeTombstoneSuffix}, {name: "RaftHardState", suffix: LocalRaftHardStateSuffix}, {name: "RangeAppliedState", suffix: LocalRangeAppliedStateSuffix}, @@ -567,6 +571,22 @@ func abortSpanKeyPrint(buf *redact.StringBuilder, key roachpb.Key) { buf.Printf("/%q", txnID) } +func replicatedSharedLocksTransactionLatchingKeyPrint(buf *redact.StringBuilder, key roachpb.Key) { + _, id, err := encoding.DecodeBytesAscending([]byte(key), nil) + if err != nil { + buf.Printf("/%q/err:%v", key, err) + return + } + + txnID, err := uuid.FromBytes(id) + if err != nil { + buf.Printf("/%q/err:%v", key, err) + return + } + + buf.Printf("/%q", txnID) +} + func print(buf *redact.StringBuilder, _ []encoding.Direction, key roachpb.Key) { buf.Printf("/%q", []byte(key)) } diff --git a/pkg/keys/printer_test.go b/pkg/keys/printer_test.go index d21dbd89a907..1a51519d6365 100644 --- a/pkg/keys/printer_test.go +++ b/pkg/keys/printer_test.go @@ -242,6 +242,7 @@ func TestPrettyPrint(t *testing.T) { {keys.StoreLossOfQuorumRecoveryCleanupActionsKey(), "/Local/Store/lossOfQuorumRecovery/cleanup", revertSupportUnknown}, {keys.AbortSpanKey(roachpb.RangeID(1000001), txnID), fmt.Sprintf(`/Local/RangeID/1000001/r/AbortSpan/%q`, txnID), revertSupportUnknown}, + {keys.ReplicatedSharedLocksTransactionLatchingKey(roachpb.RangeID(1000001), txnID), fmt.Sprintf(`/Local/RangeID/1000001/r/ReplicatedSharedLocksTransactionLatch/%q`, txnID), revertSupportUnknown}, {keys.RangeAppliedStateKey(roachpb.RangeID(1000001)), "/Local/RangeID/1000001/r/RangeAppliedState", revertSupportUnknown}, {keys.RaftTruncatedStateKey(roachpb.RangeID(1000001)), "/Local/RangeID/1000001/u/RaftTruncatedState", revertSupportUnknown}, {keys.RangeLeaseKey(roachpb.RangeID(1000001)), "/Local/RangeID/1000001/r/RangeLease", revertSupportUnknown}, diff --git a/pkg/kv/kvserver/batcheval/declare.go b/pkg/kv/kvserver/batcheval/declare.go index 74951f4d2016..35a5b3a5c292 100644 --- a/pkg/kv/kvserver/batcheval/declare.go +++ b/pkg/kv/kvserver/batcheval/declare.go @@ -50,7 +50,7 @@ func DefaultDeclareKeys( // ensures that the commands are fully isolated from conflicting transactions // when it evaluated. func DefaultDeclareIsolatedKeys( - _ ImmutableRangeState, + rs ImmutableRangeState, header *kvpb.Header, req kvpb.Request, latchSpans *spanset.SpanSet, @@ -92,7 +92,8 @@ func DefaultDeclareIsolatedKeys( // Get the correct lock strength to use for {lock,latch} spans if we're // dealing with locking read requests. if readOnlyReq, ok := req.(kvpb.LockingReadRequest); ok { - str, _ = readOnlyReq.KeyLocking() + var dur lock.Durability + str, dur = readOnlyReq.KeyLocking() switch str { case lock.None: panic(errors.AssertionFailedf("unexpected non-locking read handling")) @@ -109,6 +110,15 @@ func DefaultDeclareIsolatedKeys( // from concurrent writers operating at lower timestamps, a shared-locking // read extends this protection to all timestamps. timestamp = hlc.MaxTimestamp + if dur == lock.Replicated && header.Txn != nil { + // Concurrent replicated shared lock attempts by the same transaction + // need to be isolated from one another. We acquire a write latch on + // a per-transaction local key to achieve this. See + // https://github.com/cockroachdb/cockroach/issues/109668. + latchSpans.AddNonMVCC(spanset.SpanReadWrite, roachpb.Span{ + Key: keys.ReplicatedSharedLocksTransactionLatchingKey(rs.GetRangeID(), header.Txn.ID), + }) + } case lock.Exclusive: // Reads that acquire exclusive locks acquire write latches at the // request's timestamp. This isolates them from all concurrent writes, diff --git a/pkg/kv/kvserver/concurrency/datadriven_util_test.go b/pkg/kv/kvserver/concurrency/datadriven_util_test.go index e2e9ebac959f..b31e38b45e45 100644 --- a/pkg/kv/kvserver/concurrency/datadriven_util_test.go +++ b/pkg/kv/kvserver/concurrency/datadriven_util_test.go @@ -86,6 +86,10 @@ func scanUserPriority(t *testing.T, d *datadriven.TestData) roachpb.UserPriority func scanLockDurability(t *testing.T, d *datadriven.TestData) lock.Durability { var durS string d.ScanArgs(t, "dur", &durS) + return getLockDurability(t, d, durS) +} + +func getLockDurability(t *testing.T, d *datadriven.TestData, durS string) lock.Durability { switch durS { case "r": return lock.Replicated @@ -177,6 +181,13 @@ func scanSingleRequest( } return concurrency.GetStrength(t, d, s) } + maybeGetDur := func() lock.Durability { + s, ok := fields["dur"] + if !ok { + return lock.Unreplicated + } + return getLockDurability(t, d, s) + } switch cmd { case "get": @@ -184,6 +195,7 @@ func scanSingleRequest( r.Sequence = maybeGetSeq() r.Key = roachpb.Key(mustGetField("key")) r.KeyLockingStrength = maybeGetStr() + r.KeyLockingDurability = maybeGetDur() return &r case "scan": @@ -194,6 +206,7 @@ func scanSingleRequest( r.EndKey = roachpb.Key(v) } r.KeyLockingStrength = maybeGetStr() + r.KeyLockingDurability = maybeGetDur() return &r case "put": diff --git a/pkg/kv/kvserver/concurrency/testdata/concurrency_manager/shared_locks_latches b/pkg/kv/kvserver/concurrency/testdata/concurrency_manager/shared_locks_latches index 42b811010096..9781a62be7e0 100644 --- a/pkg/kv/kvserver/concurrency/testdata/concurrency_manager/shared_locks_latches +++ b/pkg/kv/kvserver/concurrency/testdata/concurrency_manager/shared_locks_latches @@ -641,3 +641,78 @@ finish req=req33 finish req=req34 ---- [-] finish req34: finishing request + +# ------------------------------------------------------------------------------ +# Ensure concurrent replicated shared locking requests by the same transaction +# conflict on latches. Also ensure concurrent replicated shared lock attempts +# by different transactions do not. +# ------------------------------------------------------------------------------ + +new-request name=req35 txn=txn2 ts=11,1 + get key=c str=shared dur=r +---- + +sequence req=req35 +---- +[35] sequence req35: sequencing request +[35] sequence req35: acquiring latches +[35] sequence req35: scanning lock table for conflicting locks +[35] sequence req35: sequencing complete, returned guard + +new-request name=req36 txn=txn2 ts=11,1 + scan key=a endkey=f str=shared dur=r +---- + +sequence req=req36 +---- +[36] sequence req36: sequencing request +[36] sequence req36: acquiring latches +[36] sequence req36: waiting to acquire write latch ‹/Local/RangeID/1/r/ReplicatedSharedLocksTransactionLatch/"00000002-0000-0000-0000-000000000000"›@0,0, held by write latch ‹/Local/RangeID/1/r/ReplicatedSharedLocksTransactionLatch/"00000002-0000-0000-0000-000000000000"›@0,0 +[36] sequence req36: blocked on select in spanlatch.(*Manager).waitForSignal + +new-request name=req37 txn=txn1 ts=11,1 + get key=c str=shared dur=r +---- + +sequence req=req37 +---- +[37] sequence req37: sequencing request +[37] sequence req37: acquiring latches +[37] sequence req37: scanning lock table for conflicting locks +[37] sequence req37: sequencing complete, returned guard + + +# Unreplicated shared locking request from txn2. Shouldn't conflict on latches. +new-request name=req38 txn=txn2 ts=11,1 + get key=c str=shared dur=u +---- + +sequence req=req38 +---- +[38] sequence req38: sequencing request +[38] sequence req38: acquiring latches +[38] sequence req38: scanning lock table for conflicting locks +[38] sequence req38: sequencing complete, returned guard + +debug-latch-manager +---- +write count: 3 + read count: 4 + +finish req=req35 +---- +[-] finish req35: finishing request +[36] sequence req36: scanning lock table for conflicting locks +[36] sequence req36: sequencing complete, returned guard + +finish req=req36 +---- +[-] finish req36: finishing request + +finish req=req37 +---- +[-] finish req37: finishing request + +finish req=req38 +---- +[-] finish req38: finishing request diff --git a/pkg/sql/conn_executor.go b/pkg/sql/conn_executor.go index 9ab7ad6405a0..d86314936247 100644 --- a/pkg/sql/conn_executor.go +++ b/pkg/sql/conn_executor.go @@ -1214,13 +1214,16 @@ func (ex *connExecutor) close(ctx context.Context, closeType closeType) { txnEvType = txnRollback } - // Close all portals, otherwise there will be leftover bytes. + // Close all portals and cursors, otherwise there will be leftover bytes. ex.extraTxnState.prepStmtsNamespace.closeAllPortals( ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, ) ex.extraTxnState.prepStmtsNamespaceAtTxnRewindPos.closeAllPortals( ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, ) + if err := ex.extraTxnState.sqlCursors.closeAll(false /* errorOnWithHold */); err != nil { + log.Warningf(ctx, "error closing cursors: %v", err) + } var payloadErr error if closeType == normalClose { @@ -1271,7 +1274,8 @@ func (ex *connExecutor) close(ctx context.Context, closeType closeType) { } if closeType != panicClose { - // Close all statements, prepared portals, and cursors. + // Close all statements and prepared portals. The cursors have already been + // closed. ex.extraTxnState.prepStmtsNamespace.resetToEmpty( ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, ) @@ -1279,9 +1283,6 @@ func (ex *connExecutor) close(ctx context.Context, closeType closeType) { ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, ) ex.extraTxnState.prepStmtsNamespaceMemAcc.Close(ctx) - if err := ex.extraTxnState.sqlCursors.closeAll(false /* errorOnWithHold */); err != nil { - log.Warningf(ctx, "error closing cursors: %v", err) - } } if ex.sessionTracing.Enabled() { diff --git a/pkg/sql/logictest/testdata/logic_test/plpgsql_cursor b/pkg/sql/logictest/testdata/logic_test/plpgsql_cursor new file mode 100644 index 000000000000..d8d4684a3794 --- /dev/null +++ b/pkg/sql/logictest/testdata/logic_test/plpgsql_cursor @@ -0,0 +1,400 @@ +statement ok +CREATE TABLE xy (x INT, y INT); +INSERT INTO xy VALUES (1, 2), (3, 4); + +statement ok +CREATE TABLE kv (k INT PRIMARY KEY, v INT); +INSERT INTO kv VALUES (1, 2), (3, 4); + +# Testing OPEN statements. +statement ok +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query I +FETCH FORWARD 3 FROM foo; +---- +1 + +statement ok +ABORT; + +statement error pgcode 34000 pq: cursor \"foo\" does not exist +FETCH FORWARD 3 FROM foo; + +statement ok +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + x INT := 10; + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT x; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query I +FETCH FORWARD 3 FROM foo; +---- +10 + +# TODO(drewk): postgres returns an ambiguous column error here by default, +# although it can be configured to prefer either the variable or the column. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + x INT := 10; + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT * FROM xy WHERE xy.x = x; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query II rowsort +FETCH FORWARD 10 FROM foo; +---- +1 2 +3 4 + +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + i INT := 3; + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT * FROM xy WHERE x = i; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query II +FETCH FORWARD 3 FROM foo; +---- +3 4 + +# It should be possible to fetch from the cursor incrementally. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs NO SCROLL FOR SELECT * FROM (VALUES (1, 2), (3, 4), (5, 6), (7, 8)) v(a, b); + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query II rowsort +FETCH FORWARD 1 FROM foo; +---- +1 2 + +query II rowsort +FETCH FORWARD 2 FROM foo; +---- +3 4 +5 6 + +query II rowsort +FETCH FORWARD 3 FROM foo; +---- +7 8 + +# Cursor with NO SCROLL option. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs NO SCROLL FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query I +FETCH FORWARD 3 FROM foo; +---- +1 + +# Cursor with empty-string name. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := ''; + BEGIN + OPEN curs FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query I +FETCH FORWARD 3 FROM ""; +---- +1 + +# Multiple cursors. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + curs2 STRING := 'bar'; + curs3 STRING := 'baz'; + BEGIN + OPEN curs FOR SELECT 1; + OPEN curs2 FOR SELECT 2; + OPEN curs3 FOR SELECT 3; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query I +FETCH FORWARD 3 FROM foo; +---- +1 + +query I +FETCH FORWARD 3 FROM bar; +---- +2 + +query I +FETCH FORWARD 3 FROM baz; +---- +3 + +# The cursor should reflect changes to the database state that occur before +# it is opened, but not those that happen after it is opened. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + curs2 STRING := 'bar'; + curs3 STRING := 'baz'; + BEGIN + OPEN curs FOR SELECT * FROM xy WHERE x = 99; + INSERT INTO xy VALUES (99, 99); + OPEN curs2 FOR SELECT * FROM xy WHERE x = 99; + UPDATE xy SET y = 100 WHERE x = 99; + OPEN curs3 FOR SELECT * FROM xy WHERE x = 99; + DELETE FROM xy WHERE x = 99; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; +SELECT f(); + +query II +FETCH FORWARD 3 FROM foo; +---- + +query II +FETCH FORWARD 3 FROM bar; +---- +99 99 + +query II +FETCH FORWARD 3 FROM baz; +---- +99 100 + +query II rowsort +SELECT * FROM xy; +---- +1 2 +3 4 + +statement ok +ABORT; + +# It is possible to use the OPEN statement in an implicit transaction, but the +# cursor is closed at the end of the transaction when the statement execution +# finishes. So, until FETCH is implemented, we can't actually read from the +# cursor. +statement ok +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +SELECT f(); + +statement error pgcode 34000 pq: cursor \"foo\" does not exist +FETCH FORWARD 5 FROM foo; + +statement error pgcode 0A000 pq: unimplemented: DECLARE SCROLL CURSOR +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs SCROLL FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; + +statement error pgcode 0A000 pq: unimplemented: bound cursor declarations are not yet supported. +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs CURSOR FOR SELECT 1; + BEGIN + OPEN curs; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; + +statement error pgcode 42P11 pq: cannot open INSERT query as cursor +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs FOR INSERT INTO xy VALUES (1, 1); + RETURN 0; + END +$$ LANGUAGE PLpgSQL; + +statement error pgcode 0A000 pq: unimplemented: CTE usage inside a function definition +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + i INT := 3; + curs STRING := 'foo'; + BEGIN + OPEN curs FOR WITH foo AS (SELECT * FROM xy WHERE x = i) SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; + +# TODO(drewk): once CTEs in routines are supported, the error should be: +# pgcode 0A000 pq: DECLARE CURSOR must not contain data-modifying statements in WITH +statement error pgcode 0A000 pq: unimplemented: CTE usage inside a function definition +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + i INT := 3; + curs STRING := 'foo'; + BEGIN + OPEN curs FOR WITH foo AS (INSERT INTO xy VALUES (1, 1) RETURNING x) SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; + +statement error pgcode 42601 pq: \"curs\" is not a known variable +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + i INT := 3; + BEGIN + OPEN curs FOR WITH foo AS (SELECT * FROM xy WHERE x = i) SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; + +statement ok +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING; + BEGIN + OPEN curs FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; + +statement error pgcode 0A000 pq: unimplemented: opening an unnamed cursor is not yet supported +SELECT f(); + +statement ok +ABORT; + +statement error pgcode 0A000 pq: unimplemented: opening a cursor in a routine with an exception block is not yet supported +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING; + BEGIN + OPEN curs FOR SELECT 1; + RETURN 0; + EXCEPTION + WHEN division_by_zero THEN + RETURN -1; + END +$$ LANGUAGE PLpgSQL; +BEGIN; + +statement ok +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT 1 // 0; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; + +statement error pgcode 22012 pq: division by zero +SELECT f(); + +# Conflict with an existing cursor. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; + +statement ok +DECLARE foo CURSOR FOR SELECT 100; + +statement error pgcode 42P03 pq: cursor \"foo\" already exists +SELECT f(); + +# Conflict between OPEN statements within the same routine. +statement ok +ABORT; +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + curs2 STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT 1; + OPEN curs2 FOR SELECT 2; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +BEGIN; + +statement error pgcode 42P03 pq: cursor \"foo\" already exists +SELECT f(); diff --git a/pkg/sql/logictest/tests/fakedist-disk/generated_test.go b/pkg/sql/logictest/tests/fakedist-disk/generated_test.go index 88e18d338e3c..2f4e89470795 100644 --- a/pkg/sql/logictest/tests/fakedist-disk/generated_test.go +++ b/pkg/sql/logictest/tests/fakedist-disk/generated_test.go @@ -1359,6 +1359,13 @@ func TestLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go b/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go index 926db53b5f10..992eb5c5786b 100644 --- a/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go +++ b/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go @@ -1359,6 +1359,13 @@ func TestLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/fakedist/generated_test.go b/pkg/sql/logictest/tests/fakedist/generated_test.go index f49d57d715cf..503ac7fe8b4f 100644 --- a/pkg/sql/logictest/tests/fakedist/generated_test.go +++ b/pkg/sql/logictest/tests/fakedist/generated_test.go @@ -1373,6 +1373,13 @@ func TestLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go b/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go index 69939d73d853..0baf52e63571 100644 --- a/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go +++ b/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go @@ -1345,6 +1345,13 @@ func TestLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local-mixed-22.2-23.1/generated_test.go b/pkg/sql/logictest/tests/local-mixed-22.2-23.1/generated_test.go index f642f9455c1c..39aa14d9c6af 100644 --- a/pkg/sql/logictest/tests/local-mixed-22.2-23.1/generated_test.go +++ b/pkg/sql/logictest/tests/local-mixed-22.2-23.1/generated_test.go @@ -1338,6 +1338,13 @@ func TestLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local-vec-off/generated_test.go b/pkg/sql/logictest/tests/local-vec-off/generated_test.go index 29fe95596431..b12841aafb75 100644 --- a/pkg/sql/logictest/tests/local-vec-off/generated_test.go +++ b/pkg/sql/logictest/tests/local-vec-off/generated_test.go @@ -1373,6 +1373,13 @@ func TestLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local/generated_test.go b/pkg/sql/logictest/tests/local/generated_test.go index 523563e3db47..78b0a226c816 100644 --- a/pkg/sql/logictest/tests/local/generated_test.go +++ b/pkg/sql/logictest/tests/local/generated_test.go @@ -1499,6 +1499,13 @@ func TestLogic_pgoidtype( runLogicTest(t, "pgoidtype") } +func TestLogic_plpgsql_cursor( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "plpgsql_cursor") +} + func TestLogic_poison_after_push( t *testing.T, ) { diff --git a/pkg/sql/opt/exec/execbuilder/relational.go b/pkg/sql/opt/exec/execbuilder/relational.go index b1e2007e95f5..b9ec50bd5257 100644 --- a/pkg/sql/opt/exec/execbuilder/relational.go +++ b/pkg/sql/opt/exec/execbuilder/relational.go @@ -3168,6 +3168,7 @@ func (b *Builder) buildCall(c *memo.CallExpr) (execPlan, error) { udf.TailCall, true, /* procedure */ nil, /* exceptionHandler */ + nil, /* cursorDeclaration */ ) var ep execPlan diff --git a/pkg/sql/opt/exec/execbuilder/scalar.go b/pkg/sql/opt/exec/execbuilder/scalar.go index 545d15321de1..b8a961483907 100644 --- a/pkg/sql/opt/exec/execbuilder/scalar.go +++ b/pkg/sql/opt/exec/execbuilder/scalar.go @@ -703,6 +703,7 @@ func (b *Builder) buildExistsSubquery( false, /* tailCall */ false, /* procedure */ nil, /* exceptionHandler */ + nil, /* cursorDeclaration */ ), tree.DBoolFalse, }, types.Bool), nil @@ -821,6 +822,7 @@ func (b *Builder) buildSubquery( false, /* tailCall */ false, /* procedure */ nil, /* exceptionHandler */ + nil, /* cursorDeclaration */ ), nil } @@ -878,6 +880,7 @@ func (b *Builder) buildSubquery( false, /* tailCall */ false, /* procedure */ nil, /* exceptionHandler */ + nil, /* cursorDeclaration */ ), nil } @@ -994,6 +997,7 @@ func (b *Builder) buildUDF(ctx *buildScalarCtx, scalar opt.ScalarExpr) (tree.Typ false, /* tailCall */ false, /* procedure */ nil, /* exceptionHandler */ + nil, /* cursorDeclaration */ ) } } @@ -1010,6 +1014,7 @@ func (b *Builder) buildUDF(ctx *buildScalarCtx, scalar opt.ScalarExpr) (tree.Typ udf.TailCall, false, /* procedure */ exceptionHandler, + udf.Def.CursorDeclaration, ), nil } diff --git a/pkg/sql/opt/memo/expr.go b/pkg/sql/opt/memo/expr.go index 042bea089ee6..cabde2c6c0f8 100644 --- a/pkg/sql/opt/memo/expr.go +++ b/pkg/sql/opt/memo/expr.go @@ -725,6 +725,13 @@ type UDFDefinition struct { // ExceptionBlock contains information needed for exception-handling when the // body of this routine returns an error. It can be unset. ExceptionBlock *ExceptionBlock + + // CursorDeclaration contains the information needed to open a SQL cursor with + // the result of the *first* body statement. If it is set, there will be at + // least two body statements - one to open the cursor, and one to evaluate the + // result of the routine. This invariant is enforced when the PLpgSQL routine + // is built. CursorDeclaration may be unset. + CursorDeclaration *tree.RoutineOpenCursor } // ExceptionBlock contains the information needed to match and handle errors in diff --git a/pkg/sql/opt/memo/expr_format.go b/pkg/sql/opt/memo/expr_format.go index 1cf4b6030dbd..7d38476e9076 100644 --- a/pkg/sql/opt/memo/expr_format.go +++ b/pkg/sql/opt/memo/expr_format.go @@ -957,6 +957,12 @@ func (f *ExprFmtCtx) formatScalarWithLabel( } n = tp.Child("body") for i := range udf.Def.Body { + if i == 0 && udf.Def.CursorDeclaration != nil { + // The first statement is opening a cursor. + cur := n.Child("open-cursor") + f.formatExpr(udf.Def.Body[i], cur) + continue + } f.formatExpr(udf.Def.Body[i], n) } delete(f.seenUDFs, udf.Def) diff --git a/pkg/sql/opt/memo/interner.go b/pkg/sql/opt/memo/interner.go index cb5f7f847ba7..f71d64d3d8b9 100644 --- a/pkg/sql/opt/memo/interner.go +++ b/pkg/sql/opt/memo/interner.go @@ -1231,6 +1231,33 @@ func (h *hasher) IsUDFDefinitionEqual(l, r *UDFDefinition) bool { return false } } + if l.ExceptionBlock != nil { + if r.ExceptionBlock == nil || len(l.ExceptionBlock.Actions) != len(r.ExceptionBlock.Actions) { + return false + } + for i := range l.ExceptionBlock.Actions { + if !h.IsUDFDefinitionEqual(l.ExceptionBlock.Actions[i], r.ExceptionBlock.Actions[i]) { + return false + } + if l.ExceptionBlock.Codes[i] != r.ExceptionBlock.Codes[i] { + return false + } + } + } else if r.ExceptionBlock != nil { + return false + } + if l.CursorDeclaration != nil { + if r.CursorDeclaration == nil { + return false + } + if l.CursorDeclaration.NameArgIdx != r.CursorDeclaration.NameArgIdx || + l.CursorDeclaration.Scroll != r.CursorDeclaration.Scroll || + l.CursorDeclaration.CursorSQL != r.CursorDeclaration.CursorSQL { + return false + } + } else if r.CursorDeclaration != nil { + return false + } return h.IsColListEqual(l.Params, r.Params) && l.IsRecursive == r.IsRecursive } diff --git a/pkg/sql/opt/norm/inline_funcs.go b/pkg/sql/opt/norm/inline_funcs.go index a7cd103abd91..291a47525aa8 100644 --- a/pkg/sql/opt/norm/inline_funcs.go +++ b/pkg/sql/opt/norm/inline_funcs.go @@ -413,7 +413,8 @@ func (c *CustomFuncs) InlineConstVar(f memo.FiltersExpr) memo.FiltersExpr { // 4. Its arguments are only Variable or Const expressions. // 5. It is not a record-returning function. // 6. It does not recursively call itself. -// 7. It does not have an exception-handling block. +// 7. It does not open a cursor. +// 8. It does not have an exception-handling block. // // UDFs with mutations (INSERT, UPDATE, UPSERT, DELETE) cannot be inlined, but // we do not need an explicit check for this because immutable UDFs cannot @@ -448,7 +449,7 @@ func (c *CustomFuncs) IsInlinableUDF(args memo.ScalarListExpr, udfp *memo.UDFCal } if udfp.Def.IsRecursive || udfp.Def.Volatility == volatility.Volatile || len(udfp.Def.Body) != 1 || udfp.Def.SetReturning || udfp.Def.MultiColDataSource || - udfp.Def.ExceptionBlock != nil { + udfp.Def.CursorDeclaration != nil || udfp.Def.ExceptionBlock != nil { return false } if !args.IsConstantsAndPlaceholdersAndVariables() { diff --git a/pkg/sql/opt/optbuilder/plpgsql.go b/pkg/sql/opt/optbuilder/plpgsql.go index a9ee8f084aa6..9c66ebaf0866 100644 --- a/pkg/sql/opt/optbuilder/plpgsql.go +++ b/pkg/sql/opt/optbuilder/plpgsql.go @@ -152,7 +152,14 @@ func (b *plpgsqlBuilder) init( b.ob = ob b.colRefs = colRefs b.params = params - b.decls = block.Decls + for i := range block.Decls { + switch dec := block.Decls[i].(type) { + case *ast.Declaration: + b.decls = append(b.decls, *dec) + case *ast.CursorDeclaration: + panic(unimplemented.New("bound cursors", "bound cursor declarations are not yet supported.")) + } + } b.returnType = returnType b.varTypes = make(map[tree.Name]*types.T) for _, dec := range b.decls { @@ -422,6 +429,7 @@ func (b *plpgsqlBuilder) buildPLpgSQLStatements(stmts []ast.Statement, s *scope) // The synchronous notice sending behavior is implemented in the // crdb_internal.plpgsql_raise builtin function. con := b.makeContinuation("_stmt_raise") + con.def.Volatility = volatility.Volatile b.appendBodyStmt(&con, b.buildPLpgSQLRaise(con.s, b.getRaiseArgs(con.s, t))) b.appendPlpgSQLStmts(&con, stmts[i+1:]) return b.callContinuation(&con, s) @@ -499,6 +507,57 @@ func (b *plpgsqlBuilder) buildPLpgSQLStatements(stmts []ast.Statement, s *scope) b.appendBodyStmt(&execCon, intoScope) return b.callContinuation(&execCon, s) + case *ast.Open: + // OPEN statements are used to create a CURSOR for the current session. + // This is handled by calling the plpgsql_open_cursor internal builtin + // function in a separate body statement that returns no results, similar + // to the RAISE implementation. + if b.exceptionBlock != nil { + panic(unimplemented.New("open with exception block", + "opening a cursor in a routine with an exception block is not yet supported", + )) + } + if t.Scroll == tree.Scroll { + panic(unimplemented.NewWithIssue(77102, "DECLARE SCROLL CURSOR")) + } + if t.Query == nil { + panic(unimplemented.New("bound cursor", "opening a bound cursor is not yet supported")) + } + if _, ok := t.Query.(*tree.Select); !ok { + panic(pgerror.Newf( + pgcode.InvalidCursorDefinition, "cannot open %s query as cursor", + t.Query.StatementTag(), + )) + } + openCon := b.makeContinuation("_stmt_open") + openCon.def.Volatility = volatility.Volatile + fmtCtx := b.ob.evalCtx.FmtCtx(tree.FmtSimple) + fmtCtx.FormatNode(t.Query) + _, source, _, err := openCon.s.FindSourceProvidingColumn(b.ob.ctx, t.CurVar) + if err != nil { + if pgerror.GetPGCode(err) == pgcode.UndefinedColumn { + panic(pgerror.Newf(pgcode.Syntax, "\"%s\" is not a known variable", t.CurVar)) + } + panic(err) + } + // Initialize the routine with the information needed to pipe the first + // body statement into a cursor. + openCon.def.CursorDeclaration = &tree.RoutineOpenCursor{ + NameArgIdx: source.(*scopeColumn).getParamOrd(), + Scroll: t.Scroll, + CursorSQL: fmtCtx.CloseAndGetString(), + } + openScope := b.ob.buildStmtAtRootWithScope(t.Query, nil /* desiredTypes */, openCon.s) + if openScope.expr.Relational().CanMutate { + // Cursors with mutations are invalid. + panic(pgerror.Newf(pgcode.FeatureNotSupported, + "DECLARE CURSOR must not contain data-modifying statements in WITH", + )) + } + b.appendBodyStmt(&openCon, openScope) + b.appendPlpgSQLStmts(&openCon, stmts[i+1:]) + return b.callContinuation(&openCon, s) + default: panic(unimplemented.New( "unimplemented PL/pgSQL statement", @@ -555,6 +614,7 @@ func (b *plpgsqlBuilder) buildPLpgSQLRaise(inScope *scope, args memo.ScalarListE ) raiseColName := scopeColName("").WithMetadataName(b.makeIdentifier("stmt_raise")) raiseScope := inScope.push() + b.ensureScopeHasExpr(raiseScope) b.ob.synthesizeColumn(raiseScope, raiseColName, types.Int, nil /* expr */, raiseCall) b.ob.constructProjectForScope(inScope, raiseScope) return raiseScope @@ -822,6 +882,7 @@ func (b *plpgsqlBuilder) buildEndOfFunctionRaise(inScope *scope) *scope { makeConstStr(pgcode.RoutineExceptionFunctionExecutedNoReturnStatement.String()), /* code */ } con := b.makeContinuation("_end_of_function") + con.def.Volatility = volatility.Volatile b.appendBodyStmt(&con, b.buildPLpgSQLRaise(con.s, args)) // Build a dummy statement that returns NULL. It won't be executed, but // ensures that the continuation routine's return type is correct. @@ -912,11 +973,7 @@ func (b *plpgsqlBuilder) callContinuation(con *continuation, s *scope) *scope { if err != nil { panic(err) } - if source != nil { - args = append(args, b.ob.factory.ConstructVariable(source.(*scopeColumn).id)) - } else { - args = append(args, b.ob.factory.ConstructNull(typ)) - } + args = append(args, b.ob.factory.ConstructVariable(source.(*scopeColumn).id)) } for _, dec := range b.decls { addArg(dec.Var, b.varTypes[dec.Var]) diff --git a/pkg/sql/opt/optbuilder/scope_column.go b/pkg/sql/opt/optbuilder/scope_column.go index 4ca6686953a3..1074acd3c1e3 100644 --- a/pkg/sql/opt/optbuilder/scope_column.go +++ b/pkg/sql/opt/optbuilder/scope_column.go @@ -135,6 +135,15 @@ func (c *scopeColumn) setParamOrd(ord int) { c.paramOrd = funcParamOrd(ord + 1) } +// getParamOrd retrieves the 0-based ordinal from the column's 1-based function +// parameter ordinal. Panics if the function ordinal is unset. +func (c *scopeColumn) getParamOrd() int { + if c.paramOrd < 1 { + panic(errors.AssertionFailedf("expected non-negative argument ordinal")) + } + return int(c.paramOrd) - 1 +} + // funcParamReferencedBy returns true if the scopeColumn is a function parameter // column that can be referenced by the given placeholder. func (c *scopeColumn) funcParamReferencedBy(idx tree.PlaceholderIdx) bool { diff --git a/pkg/sql/opt/optbuilder/testdata/udf_plpgsql b/pkg/sql/opt/optbuilder/testdata/udf_plpgsql index eacc4350caeb..4c21e18981c3 100644 --- a/pkg/sql/opt/optbuilder/testdata/udf_plpgsql +++ b/pkg/sql/opt/optbuilder/testdata/udf_plpgsql @@ -4603,3 +4603,222 @@ project │ └── projections │ └── const: 0 [as=stmt_return_9:20] └── const: 1 + +# Testing OPEN statement. +exec-ddl +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT 1; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +---- + +build format=show-scalars +SELECT f(); +---- +project + ├── columns: f:6 + ├── values + │ └── tuple + └── projections + └── udf: f [as=f:6] + └── body + └── limit + ├── columns: "_stmt_open_1":5 + ├── project + │ ├── columns: "_stmt_open_1":5 + │ ├── project + │ │ ├── columns: curs:1!null + │ │ ├── values + │ │ │ └── tuple + │ │ └── projections + │ │ └── const: 'foo' [as=curs:1] + │ └── projections + │ └── udf: _stmt_open_1 [as="_stmt_open_1":5] + │ ├── args + │ │ └── variable: curs:1 + │ ├── params: curs:2 + │ └── body + │ ├── open-cursor + │ │ └── project + │ │ ├── columns: "?column?":3!null + │ │ ├── values + │ │ │ └── tuple + │ │ └── projections + │ │ └── const: 1 [as="?column?":3] + │ └── project + │ ├── columns: stmt_return_2:4!null + │ ├── values + │ │ └── tuple + │ └── projections + │ └── const: 0 [as=stmt_return_2:4] + └── const: 1 + +exec-ddl +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + i INT := 3; + curs STRING := 'foo'; + BEGIN + OPEN curs FOR SELECT * FROM xy WHERE x = i; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +---- + +build format=show-scalars +SELECT f(); +---- +project + ├── columns: f:12 + ├── values + │ └── tuple + └── projections + └── udf: f [as=f:12] + └── body + └── limit + ├── columns: "_stmt_open_1":11 + ├── project + │ ├── columns: "_stmt_open_1":11 + │ ├── project + │ │ ├── columns: curs:2!null i:1!null + │ │ ├── project + │ │ │ ├── columns: i:1!null + │ │ │ ├── values + │ │ │ │ └── tuple + │ │ │ └── projections + │ │ │ └── const: 3 [as=i:1] + │ │ └── projections + │ │ └── const: 'foo' [as=curs:2] + │ └── projections + │ └── udf: _stmt_open_1 [as="_stmt_open_1":11] + │ ├── args + │ │ ├── variable: i:1 + │ │ └── variable: curs:2 + │ ├── params: i:3 curs:4 + │ └── body + │ ├── open-cursor + │ │ └── project + │ │ ├── columns: x:5!null y:6 + │ │ └── select + │ │ ├── columns: x:5!null y:6 rowid:7!null crdb_internal_mvcc_timestamp:8 tableoid:9 + │ │ ├── scan xy + │ │ │ └── columns: x:5 y:6 rowid:7!null crdb_internal_mvcc_timestamp:8 tableoid:9 + │ │ └── filters + │ │ └── eq + │ │ ├── variable: x:5 + │ │ └── variable: i:3 + │ └── project + │ ├── columns: stmt_return_2:10!null + │ ├── values + │ │ └── tuple + │ └── projections + │ └── const: 0 [as=stmt_return_2:10] + └── const: 1 + +exec-ddl +CREATE OR REPLACE FUNCTION f() RETURNS INT AS $$ + DECLARE + curs STRING := 'foo'; + curs2 STRING := 'bar'; + curs3 STRING := 'baz'; + BEGIN + OPEN curs FOR SELECT 1; + OPEN curs2 FOR SELECT 2; + OPEN curs3 FOR SELECT 3; + RETURN 0; + END +$$ LANGUAGE PLpgSQL; +---- + +build format=show-scalars +SELECT f(); +---- +project + ├── columns: f:20 + ├── values + │ └── tuple + └── projections + └── udf: f [as=f:20] + └── body + └── limit + ├── columns: "_stmt_open_1":19 + ├── project + │ ├── columns: "_stmt_open_1":19 + │ ├── project + │ │ ├── columns: curs3:3!null curs:1!null curs2:2!null + │ │ ├── project + │ │ │ ├── columns: curs2:2!null curs:1!null + │ │ │ ├── project + │ │ │ │ ├── columns: curs:1!null + │ │ │ │ ├── values + │ │ │ │ │ └── tuple + │ │ │ │ └── projections + │ │ │ │ └── const: 'foo' [as=curs:1] + │ │ │ └── projections + │ │ │ └── const: 'bar' [as=curs2:2] + │ │ └── projections + │ │ └── const: 'baz' [as=curs3:3] + │ └── projections + │ └── udf: _stmt_open_1 [as="_stmt_open_1":19] + │ ├── args + │ │ ├── variable: curs:1 + │ │ ├── variable: curs2:2 + │ │ └── variable: curs3:3 + │ ├── params: curs:4 curs2:5 curs3:6 + │ └── body + │ ├── open-cursor + │ │ └── project + │ │ ├── columns: "?column?":7!null + │ │ ├── values + │ │ │ └── tuple + │ │ └── projections + │ │ └── const: 1 [as="?column?":7] + │ └── project + │ ├── columns: "_stmt_open_2":18 + │ ├── values + │ │ └── tuple + │ └── projections + │ └── udf: _stmt_open_2 [as="_stmt_open_2":18] + │ ├── args + │ │ ├── variable: curs:4 + │ │ ├── variable: curs2:5 + │ │ └── variable: curs3:6 + │ ├── params: curs:8 curs2:9 curs3:10 + │ └── body + │ ├── open-cursor + │ │ └── project + │ │ ├── columns: "?column?":11!null + │ │ ├── values + │ │ │ └── tuple + │ │ └── projections + │ │ └── const: 2 [as="?column?":11] + │ └── project + │ ├── columns: "_stmt_open_3":17 + │ ├── values + │ │ └── tuple + │ └── projections + │ └── udf: _stmt_open_3 [as="_stmt_open_3":17] + │ ├── args + │ │ ├── variable: curs:8 + │ │ ├── variable: curs2:9 + │ │ └── variable: curs3:10 + │ ├── params: curs:12 curs2:13 curs3:14 + │ └── body + │ ├── open-cursor + │ │ └── project + │ │ ├── columns: "?column?":15!null + │ │ ├── values + │ │ │ └── tuple + │ │ └── projections + │ │ └── const: 3 [as="?column?":15] + │ └── project + │ ├── columns: stmt_return_4:16!null + │ ├── values + │ │ └── tuple + │ └── projections + │ └── const: 0 [as=stmt_return_4:16] + └── const: 1 diff --git a/pkg/sql/plpgsql/parser/lexer.go b/pkg/sql/plpgsql/parser/lexer.go index 442f85d5e0e3..517540e5d1ed 100644 --- a/pkg/sql/plpgsql/parser/lexer.go +++ b/pkg/sql/plpgsql/parser/lexer.go @@ -246,56 +246,6 @@ func (l *lexer) MakeDynamicExecuteStmt() *plpgsqltree.DynamicExecute { return ret } -func (l *lexer) ProcessForOpenCursor(nullCursorExplicitExpr bool) *plpgsqltree.Open { - openStmt := &plpgsqltree.Open{} - openStmt.CursorOptions = plpgsqltree.CursorOptionFastPlan.Mask() - - if nullCursorExplicitExpr { - if l.Peek().id == NO { - l.lastPos++ - if l.Peek().id == SCROLL { - openStmt.CursorOptions |= plpgsqltree.CursorOptionNoScroll.Mask() - l.lastPos++ - } - } else if l.Peek().id == SCROLL { - openStmt.CursorOptions |= plpgsqltree.CursorOptionScroll.Mask() - l.lastPos++ - } - - if l.Peek().id != FOR { - l.setErr(pgerror.New(pgcode.Syntax, "syntax error, expected \"FOR\"")) - return nil - } - - l.lastPos++ - if l.Peek().id == EXECUTE { - l.lastPos++ - dynamicQuery, endToken := l.ReadSqlExpressionStr2(USING, ';') - openStmt.DynamicQuery = dynamicQuery - l.lastPos++ - if endToken == USING { - // Continue reading for params for the sql expression till the ending - // token is not a comma. - openStmt.Params = make([]string, 0) - for { - param, endToken := l.ReadSqlExpressionStr2(',', ';') - openStmt.Params = append(openStmt.Params, param) - if endToken != ',' { - break - } - l.lastPos++ - } - } - } else { - openStmt.Query = l.ReadSqlExpressionStr(';') - } - } else { - // read_cursor_args() - openStmt.ArgQuery = "hello" - } - return openStmt -} - // ReadSqlExpressionStr returns the string from the l.lastPos till it sees // the terminator for the first time. The returned string is made by tokens // between the starting index (included) to the terminator (not included). @@ -360,6 +310,62 @@ func (l *lexer) readSQLConstruct( return startPos, endPos, terminatorMet } +func (l *lexer) MakeFetchOrMoveStmt(isMove bool) (plpgsqltree.Statement, error) { + if l.parser.Lookahead() != -1 { + // Push back the lookahead token so that it can be included. + l.PushBack(1) + } + prefix := "FETCH " + if isMove { + prefix = "MOVE " + } + sqlStr, terminator := l.ReadSqlConstruct(INTO, ';') + sqlStr = prefix + sqlStr + sqlStmt, err := parser.ParseOne(sqlStr) + if err != nil { + return nil, err + } + var cursor tree.CursorStmt + switch t := sqlStmt.AST.(type) { + case *tree.FetchCursor: + cursor = t.CursorStmt + case *tree.MoveCursor: + cursor = t.CursorStmt + default: + return nil, errors.Newf("invalid FETCH or MOVE syntax") + } + var target []plpgsqltree.Variable + if !isMove { + if terminator != INTO { + return nil, errors.Newf("invalid syntax for FETCH") + } + // Read past the INTO. + l.lastPos++ + startPos, endPos, _ := l.readSQLConstruct(';') + for pos := startPos; pos < endPos; pos += 2 { + tok := l.tokens[pos] + if tok.id != IDENT { + return nil, errors.Newf("\"%s\" is not a scalar variable", tok.str) + } + if pos+1 != endPos && l.tokens[pos+1].id != ',' { + return nil, errors.Newf("expected INTO target to be a comma-separated list") + } + variable := plpgsqltree.Variable(strings.TrimSpace(l.getStr(pos, pos+1))) + target = append(target, variable) + } + if len(target) == 0 { + return nil, errors.Newf("expected INTO target") + } + } + // Move past the semicolon. + l.lastPos++ + return &plpgsqltree.Fetch{ + Cursor: cursor, + Target: target, + IsMove: isMove, + }, nil +} + func (l *lexer) ReadSqlConstruct( terminator1 int, terminators ...int, ) (sqlStr string, terminatorMet int) { @@ -380,26 +386,6 @@ func (l *lexer) getStr(startPos, endPos int) string { return l.in[start:end] } -func (l *lexer) ProcessQueryForCursorWithoutExplicitExpr(openStmt *plpgsqltree.Open) { - l.lastPos++ - if int(l.Peek().id) == EXECUTE { - dynamicQuery, endToken := l.ReadSqlExpressionStr2(USING, ';') - openStmt.DynamicQuery = dynamicQuery - if endToken == USING { - var expr string - for { - expr, endToken = l.ReadSqlExpressionStr2(',', ';') - openStmt.Params = append(openStmt.Params, expr) - if endToken != ',' { - break - } - } - } - } else { - openStmt.Query = l.ReadSqlExpressionStr(';') - } -} - // Peek peeks func (l *lexer) Peek() plpgsqlSymType { if l.lastPos+1 < len(l.tokens) { diff --git a/pkg/sql/plpgsql/parser/plpgsql.y b/pkg/sql/plpgsql/parser/plpgsql.y index 9cfc31443304..5ddc6a4698af 100644 --- a/pkg/sql/plpgsql/parser/plpgsql.y +++ b/pkg/sql/plpgsql/parser/plpgsql.y @@ -2,6 +2,7 @@ package parser import ( + "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/scanner" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree" @@ -124,7 +125,7 @@ func (u *plpgsqlSymUnion) open() *plpgsqltree.Open { func (u *plpgsqlSymUnion) expr() plpgsqltree.Expr { if u.val == nil { - return nil + return nil } return u.val.(plpgsqltree.Expr) } @@ -133,14 +134,6 @@ func (u *plpgsqlSymUnion) exprs() []plpgsqltree.Expr { return u.val.([]plpgsqltree.Expr) } -func (u *plpgsqlSymUnion) declaration() *plpgsqltree.Declaration { - return u.val.(*plpgsqltree.Declaration) -} - -func (u *plpgsqlSymUnion) declarations() []plpgsqltree.Declaration { - return u.val.([]plpgsqltree.Declaration) -} - func (u *plpgsqlSymUnion) raiseOption() *plpgsqltree.RaiseOption { return u.val.(*plpgsqltree.RaiseOption) } @@ -166,6 +159,14 @@ func (u *plpgsqlSymUnion) conditions() []plpgsqltree.Condition { return u.val.([]plpgsqltree.Condition) } +func (u *plpgsqlSymUnion) cursorScrollOption() tree.CursorScrollOption { + return u.val.(tree.CursorScrollOption) +} + +func (u *plpgsqlSymUnion) sqlStatement() tree.Statement { + return u.val.(tree.Statement) +} + %} /* * Basic non-keyword token types. These are hard-wired into the core lexer. @@ -315,7 +316,6 @@ func (u *plpgsqlSymUnion) conditions() []plpgsqltree.Condition { %type decl_datatype %type decl_collate -%type <*plpgsqltree.Open> open_stmt_processor %type expr_until_semi expr_until_paren %type expr_until_then expr_until_loop opt_expr_until_when %type opt_exitcond @@ -340,8 +340,8 @@ func (u *plpgsqlSymUnion) conditions() []plpgsqltree.Condition { %type stmt_commit stmt_rollback %type stmt_case stmt_foreach_a -%type <*plpgsqltree.Declaration> decl_stmt decl_statement -%type <[]plpgsqltree.Declaration> decl_sect opt_decl_stmts decl_stmts +%type decl_stmt decl_statement +%type <[]plpgsqltree.Statement> decl_sect opt_decl_stmts decl_stmts %type <[]plpgsqltree.Exception> exception_sect proc_exceptions %type <*plpgsqltree.Exception> proc_exception @@ -362,9 +362,7 @@ func (u *plpgsqlSymUnion) conditions() []plpgsqltree.Condition { %type format_expr %type <[]plpgsqltree.Expr> opt_format_exprs format_exprs -%type opt_scrollable - -%type <*plpgsqltree.Fetch> opt_fetch_direction +%type opt_scrollable %type <*tree.NumVal> opt_transaction_chain @@ -385,7 +383,7 @@ pl_block: opt_block_label decl_sect BEGIN proc_sect exception_sect END opt_label { $$.val = &plpgsqltree.Block{ Label: $1, - Decls: $2.declarations(), + Decls: $2.statements(), Body: $4.statements(), Exceptions: $5.exceptions(), } @@ -394,54 +392,46 @@ pl_block: opt_block_label decl_sect BEGIN proc_sect exception_sect END opt_label decl_sect: DECLARE opt_decl_stmts { - $$.val = $2.declarations() + $$.val = $2.statements() } | /* EMPTY */ { // Use a nil slice to indicate DECLARE was not used. - $$.val = []plpgsqltree.Declaration(nil) + $$.val = []plpgsqltree.Statement(nil) } ; opt_decl_stmts: decl_stmts { - $$.val = $1.declarations() + $$.val = $1.statements() } | /* EMPTY */ { - $$.val = []plpgsqltree.Declaration{} + $$.val = []plpgsqltree.Statement{} } ; decl_stmts: decl_stmts decl_stmt { - decs := $1.declarations() - dec := $2.declaration() - if dec == nil { - $$.val = decs - } else { - $$.val = append(decs, *dec) - } + decs := $1.statements() + dec := $2.statement() + $$.val = append(decs, dec) } | decl_stmt { - dec := $1.declaration() - if dec == nil { - $$.val = []plpgsqltree.Declaration{} - } else { - $$.val = []plpgsqltree.Declaration{*dec} - } + dec := $1.statement() + $$.val = []plpgsqltree.Statement{dec} } ; decl_stmt : decl_statement { - $$.val = $1.declaration() + $$.val = $1.statement() } | DECLARE { // This is to allow useless extra "DECLARE" keywords in the declare section. - $$.val = (*plpgsqltree.Declaration)(nil) + $$.val = (plpgsqltree.Statement)(nil) } // TODO(chengxiong): turn this block on and throw useful error if user // tries to put the block label just before BEGIN instead of before @@ -466,36 +456,48 @@ decl_statement: decl_varname decl_const decl_datatype decl_collate decl_notnull { return unimplemented(plpgsqllex, "alias for") } -| decl_varname opt_scrollable CURSOR decl_cursor_args decl_is_for decl_cursor_query ';' +| decl_varname opt_scrollable CURSOR decl_cursor_args decl_is_for decl_cursor_query { - return unimplemented(plpgsqllex, "cursor") + $$.val = &plpgsqltree.CursorDeclaration{ + Name: plpgsqltree.Variable($1), + Scroll: $2.cursorScrollOption(), + Query: $6.sqlStatement(), + } } ; opt_scrollable: { - return unimplemented(plpgsqllex, "cursor") + $$.val = tree.UnspecifiedScroll } | NO_SCROLL SCROLL { - return unimplemented(plpgsqllex, "cursor") + $$.val = tree.NoScroll } | SCROLL { - return unimplemented(plpgsqllex, "cursor") + $$.val = tree.Scroll } ; -decl_cursor_query: +decl_cursor_query: expr_until_semi ';' { - plpgsqllex.(*lexer).ReadSqlExpressionStr(';') + stmts, err := parser.Parse($1) + if err != nil { + return setErr(plpgsqllex, err) + } + if len(stmts) != 1 { + return setErr(plpgsqllex, errors.New("expected exactly one SQL statement for cursor")) + } + $$.val = stmts[0].AST } ; -decl_cursor_args: +decl_cursor_args: '(' { + return unimplemented(plpgsqllex, "cursor arguments") } -| '(' decl_cursor_arglist ')' +| /* EMPTY */ { } ; @@ -687,11 +689,17 @@ proc_stmt:pl_block ';' | stmt_getdiag { } | stmt_open - { } + { + $$.val = $1.statement() + } | stmt_fetch - { } + { + $$.val = $1.statement() + } | stmt_move - { } + { + $$.val = $1.statement() + } | stmt_close { $$.val = $1.statement() @@ -1247,35 +1255,54 @@ stmt_dynexecute: EXECUTE } ; -// TODO: change expr_until_semi to process_cursor_before_semi -stmt_open: OPEN IDENT open_stmt_processor ';' +stmt_open: OPEN IDENT ';' { - openCursorStmt := $3.open() - openCursorStmt.CursorName = $2 - $$.val = openCursorStmt + $$.val = &plpgsqltree.Open{CurVar: plpgsqltree.Variable($2)} } -; - -stmt_fetch: FETCH opt_fetch_direction IDENT INTO +| OPEN IDENT opt_scrollable FOR EXECUTE { - return unimplemented(plpgsqllex, "fetch") + return unimplemented(plpgsqllex, "cursor for execute") + } +| OPEN IDENT opt_scrollable FOR expr_until_semi ';' + { + stmts, err := parser.Parse($5) + if err != nil { + return setErr(plpgsqllex, err) + } + if len(stmts) != 1 { + return setErr(plpgsqllex, errors.New("expected exactly one SQL statement for cursor")) + } + $$.val = &plpgsqltree.Open{ + CurVar: plpgsqltree.Variable($2), + Scroll: $3.cursorScrollOption(), + Query: stmts[0].AST, + } } ; -stmt_move: MOVE opt_fetch_direction IDENT ';' +stmt_fetch: FETCH { - return unimplemented(plpgsqllex, "move") + fetch, err := plpgsqllex.(*lexer).MakeFetchOrMoveStmt(false) + if err != nil { + return setErr(plpgsqllex, err) + } + $$.val = fetch } ; -opt_fetch_direction: +stmt_move: MOVE { - return unimplemented(plpgsqllex, "fetch direction") + move, err := plpgsqllex.(*lexer).MakeFetchOrMoveStmt(true) + if err != nil { + return setErr(plpgsqllex, err) + } + $$.val = move } +; -stmt_close: CLOSE cursor_variable ';' +stmt_close: CLOSE IDENT ';' { - $$.val = &plpgsqltree.Close{} + $$.val = &plpgsqltree.Close{CurVar: plpgsqltree.Variable($2)} } ; @@ -1305,12 +1332,6 @@ AND CHAIN | /* EMPTY */ { } -cursor_variable: IDENT - { - unimplemented(plpgsqllex, "cursor variable") - } -; - exception_sect: /* EMPTY */ { $$.val = []plpgsqltree.Exception(nil) @@ -1364,11 +1385,6 @@ proc_condition: any_identifier } ; -open_stmt_processor: - { - $$.val = plpgsqllex.(*lexer).ProcessForOpenCursor(true) - } - expr_until_semi: { $$ = plpgsqllex.(*lexer).ReadSqlExpressionStr(';') diff --git a/pkg/sql/plpgsql/parser/testdata/decl_header b/pkg/sql/plpgsql/parser/testdata/decl_header index 086024b6f5fe..3cf7d9132b35 100644 --- a/pkg/sql/plpgsql/parser/testdata/decl_header +++ b/pkg/sql/plpgsql/parser/testdata/decl_header @@ -40,10 +40,21 @@ END ---- at or near ";": syntax error: unimplemented: this syntax +parse +DECLARE + var1 CURSOR FOR SELECT * FROM t1 WHERE id = arg1; +BEGIN +END +---- +DECLARE +var1 CURSOR FOR SELECT * FROM t1 WHERE id = arg1; +BEGIN +END + parse DECLARE var1 NO SCROLL CURSOR (arg1 INTEGER) FOR SELECT * FROM t1 WHERE id = arg1; BEGIN END ---- -at or near "scroll": syntax error: unimplemented: this syntax +at or near "(": syntax error: unimplemented: this syntax diff --git a/pkg/sql/plpgsql/parser/testdata/stmt_close b/pkg/sql/plpgsql/parser/testdata/stmt_close index b61e63fc98bf..90edb47193a7 100644 --- a/pkg/sql/plpgsql/parser/testdata/stmt_close +++ b/pkg/sql/plpgsql/parser/testdata/stmt_close @@ -6,5 +6,5 @@ END ---- DECLARE BEGIN -CLOSE a cursor +CLOSE some_cursor; END diff --git a/pkg/sql/plpgsql/parser/testdata/stmt_fetch_move b/pkg/sql/plpgsql/parser/testdata/stmt_fetch_move index ac26c7b4b811..bceef82ba3b2 100644 --- a/pkg/sql/plpgsql/parser/testdata/stmt_fetch_move +++ b/pkg/sql/plpgsql/parser/testdata/stmt_fetch_move @@ -4,7 +4,10 @@ BEGIN MOVE NEXT FROM emp_cur; END ---- -at or near "move": syntax error: unimplemented: this syntax +DECLARE +BEGIN +MOVE 1 FROM emp_cur; +END parse DECLARE @@ -12,7 +15,10 @@ BEGIN MOVE PRIOR FROM var; END ---- -at or near "move": syntax error: unimplemented: this syntax +DECLARE +BEGIN +MOVE -1 FROM var; +END parse DECLARE @@ -20,7 +26,10 @@ BEGIN FETCH NEXT FROM emp_cur INTO x,y; END ---- -at or near "fetch": syntax error: unimplemented: this syntax +DECLARE +BEGIN +FETCH 1 FROM emp_cur INTO x, y; +END parse DECLARE @@ -28,7 +37,10 @@ BEGIN FETCH emp_cur INTO x,y; END ---- -at or near "fetch": syntax error: unimplemented: this syntax +DECLARE +BEGIN +FETCH 1 FROM emp_cur INTO x, y; +END parse DECLARE @@ -36,4 +48,284 @@ BEGIN FETCH ABSOLUTE 2 FROM emp_cur INTO x,y; END ---- -at or near "fetch": syntax error: unimplemented: this syntax +DECLARE +BEGIN +FETCH ABSOLUTE 2 FROM emp_cur INTO x, y; +END + +parse +DECLARE +BEGIN +FETCH emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH 1 FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH emp_cur INTO; +END +---- +at or near "into": syntax error: expected INTO target + +parse +DECLARE +BEGIN +FETCH emp_cur; +END +---- +at or near "emp_cur": syntax error: invalid syntax for FETCH + +parse +DECLARE +BEGIN +MOVE NEXT FROM emp_cur INTO x, y; +END +---- +at or near ";": at or near "x": syntax error + +parse +DECLARE +BEGIN +MOVE NEXT FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE 1 FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE PRIOR FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE -1 FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE FIRST FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE FIRST FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE LAST FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE LAST FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE ABSOLUTE 5 FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE ABSOLUTE 5 FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE FIRST FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE FIRST FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE RELATIVE 3 FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE RELATIVE 3 FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE FORWARD 3 FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE 3 FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE BACKWARD 3 FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE -3 FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE FORWARD ALL FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE ALL FROM emp_cur; +END + +parse +DECLARE +BEGIN +MOVE BACKWARD ALL FROM emp_cur; +END +---- +DECLARE +BEGIN +MOVE BACKWARD ALL FROM emp_cur; +END + +parse +DECLARE +BEGIN +FETCH NEXT FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH 1 FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH PRIOR FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH -1 FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH FIRST FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH FIRST FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH LAST FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH LAST FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH ABSOLUTE 5 FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH ABSOLUTE 5 FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH FIRST FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH FIRST FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH RELATIVE 3 FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH RELATIVE 3 FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH FORWARD 3 FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH 3 FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH BACKWARD 3 FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH -3 FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH FORWARD ALL FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH ALL FROM emp_cur INTO x; +END + +parse +DECLARE +BEGIN +FETCH BACKWARD ALL FROM emp_cur INTO x; +END +---- +DECLARE +BEGIN +FETCH BACKWARD ALL FROM emp_cur INTO x; +END diff --git a/pkg/sql/plpgsql/parser/testdata/stmt_open b/pkg/sql/plpgsql/parser/testdata/stmt_open index 73a061c8e2a3..d91b472f6352 100644 --- a/pkg/sql/plpgsql/parser/testdata/stmt_open +++ b/pkg/sql/plpgsql/parser/testdata/stmt_open @@ -1,22 +1,51 @@ parse DECLARE BEGIN -OPEN curs1 NO SCROLL FOR SELECT * FROM foo WHERE key = mykey; +OPEN curs1; END ---- DECLARE BEGIN -OPEN curs1 NO SCROLL FOR SELECT * FROM foo WHERE key = mykey +OPEN curs1; END +parse +DECLARE +BEGIN +OPEN curs1 FOR SELECT * FROM foo WHERE key = mykey; +END +---- +DECLARE +BEGIN +OPEN curs1 FOR SELECT * FROM foo WHERE key = mykey; +END parse DECLARE BEGIN -OPEN curs2 SCROLL FOR EXECUTE SELECT $1, $2 FROM foo WHERE key = mykey USING hello, jojo; +OPEN curs1 SCROLL FOR SELECT * FROM foo WHERE key = mykey; +END +---- +DECLARE +BEGIN +OPEN curs1 SCROLL FOR SELECT * FROM foo WHERE key = mykey; +END + +parse +DECLARE +BEGIN +OPEN curs1 NO SCROLL FOR SELECT * FROM foo WHERE key = mykey; END ---- DECLARE BEGIN -OPEN curs2 SCROLL FOR EXECUTE SELECT $1, $2 FROM foo WHERE key = mykey USING [hello jojo] +OPEN curs1 NO SCROLL FOR SELECT * FROM foo WHERE key = mykey; +END + +parse +DECLARE +BEGIN +OPEN curs2 SCROLL FOR EXECUTE SELECT $1, $2 FROM foo WHERE key = mykey USING hello, jojo; END +---- +at or near "execute": syntax error: unimplemented: this syntax diff --git a/pkg/sql/routine.go b/pkg/sql/routine.go index fc5a7c2f5e49..e8602767195f 100644 --- a/pkg/sql/routine.go +++ b/pkg/sql/routine.go @@ -15,6 +15,8 @@ import ( "strconv" "github.com/cockroachdb/cockroach/pkg/kv" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/colinfo" + "github.com/cockroachdb/cockroach/pkg/sql/isql" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/plpgsql" @@ -22,6 +24,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/util" + "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/errors" ) @@ -226,23 +230,30 @@ func (g *routineGenerator) startInternal(ctx context.Context, txn *kv.Txn) (err stmtIdx := 0 ef := newExecFactory(ctx, g.p) rrw := NewRowResultWriter(&g.rch) + var cursorHelper *plpgsqlCursorHelper err = g.expr.ForEachPlan(ctx, ef, g.args, func(plan tree.RoutinePlan, isFinalPlan bool) error { stmtIdx++ opName := "udf-stmt-" + g.expr.Name + "-" + strconv.Itoa(stmtIdx) ctx, sp := tracing.ChildSpan(ctx, opName) defer sp.Finish() - // If this is the last statement and it is not a procedure, use the - // rowResultWriter created above. Otherwise, use a rowResultWriter that - // drops all rows added to it. - // - // We can use a droppingResultWriter for all statements in a procedure - // because we do not yet allow OUT or INOUT parameters, so a procedure - // never returns values. var w rowResultWriter + openCursor := stmtIdx == 1 && g.expr.CursorDeclaration != nil if isFinalPlan && !g.expr.Procedure { + // The result of this statement is the routine's output. This is never the + // case for a procedure, which does not output any rows (since we do not + // yet support OUT or INOUT parameters). w = rrw + } else if openCursor { + // The result of the first statement will be used to open a SQL cursor. + cursorHelper, err = g.newCursorHelper(ctx, plan.(*planComponents)) + if err != nil { + return err + } + w = NewRowResultWriter(&cursorHelper.container) } else { + // The result of this statement is not needed. Use a rowResultWriter that + // drops all rows added to it. w = &droppingResultWriter{} } @@ -259,9 +270,15 @@ func (g *routineGenerator) startInternal(ctx context.Context, txn *kv.Txn) (err if err != nil { return err } + if openCursor { + return cursorHelper.createCursor(g.p) + } return nil }) if err != nil { + if cursorHelper != nil { + err = errors.CombineErrors(err, cursorHelper.Close()) + } return g.handleException(ctx, err) } @@ -359,3 +376,109 @@ func (d *droppingResultWriter) SetError(err error) { func (d *droppingResultWriter) Err() error { return d.err } + +func (g *routineGenerator) newCursorHelper( + ctx context.Context, plan *planComponents, +) (*plpgsqlCursorHelper, error) { + open := g.expr.CursorDeclaration + if open.NameArgIdx < 0 || open.NameArgIdx >= len(g.args) { + panic(errors.AssertionFailedf("unexpected name argument index: %d", open.NameArgIdx)) + } + if g.args[open.NameArgIdx] == tree.DNull { + return nil, unimplemented.New("unnamed cursor", + "opening an unnamed cursor is not yet supported", + ) + } + planCols := plan.main.planColumns() + cursorHelper := &plpgsqlCursorHelper{ + ctx: ctx, + cursorName: tree.Name(tree.MustBeDString(g.args[open.NameArgIdx])), + resultCols: make(colinfo.ResultColumns, len(planCols)), + } + copy(cursorHelper.resultCols, planCols) + cursorHelper.container.Init( + ctx, + getTypesFromResultColumns(planCols), + g.p.ExtendedEvalContextCopy(), + "routine_open_cursor", /* opName */ + ) + return cursorHelper, nil +} + +// plpgsqlCursorHelper wraps a row container in order to feed the results of +// executing a SQL statement to a SQL cursor. Note that the SQL statement is not +// lazily executed; its entire result is written to the container. +// TODO(drewk): while the row container can spill to disk, we should default to +// lazy execution for cursors for performance reasons. +type plpgsqlCursorHelper struct { + ctx context.Context + cursorName tree.Name + cursorSql string + + // Fields related to implementing the isql.Rows interface. + container rowContainerHelper + iter *rowContainerIterator + resultCols colinfo.ResultColumns + lastRow tree.Datums + lastErr error + rowsAffected int +} + +func (h *plpgsqlCursorHelper) createCursor(p *planner) error { + h.iter = newRowContainerIterator(h.ctx, h.container) + cursor := &sqlCursor{ + Rows: h, + readSeqNum: p.txn.GetReadSeqNum(), + txn: p.txn, + statement: h.cursorSql, + created: timeutil.Now(), + } + if err := p.checkIfCursorExists(h.cursorName); err != nil { + return err + } + return p.sqlCursors.addCursor(h.cursorName, cursor) +} + +var _ isql.Rows = &plpgsqlCursorHelper{} + +// Next implements the isql.Rows interface. +func (h *plpgsqlCursorHelper) Next(_ context.Context) (bool, error) { + h.lastRow, h.lastErr = h.iter.Next() + if h.lastErr != nil { + return false, h.lastErr + } + if h.lastRow != nil { + h.rowsAffected++ + } + return h.lastRow != nil, nil +} + +// Cur implements the isql.Rows interface. +func (h *plpgsqlCursorHelper) Cur() tree.Datums { + return h.lastRow +} + +// RowsAffected implements the isql.Rows interface. +func (h *plpgsqlCursorHelper) RowsAffected() int { + return h.rowsAffected +} + +// Close implements the isql.Rows interface. +func (h *plpgsqlCursorHelper) Close() error { + if h.iter != nil { + h.iter.Close() + h.iter = nil + } + h.container.Close(h.ctx) + return h.lastErr +} + +// Types implements the isql.Rows interface. +func (h *plpgsqlCursorHelper) Types() colinfo.ResultColumns { + return h.resultCols +} + +// HasResults implements the isql.Rows interface. +func (h *plpgsqlCursorHelper) HasResults() bool { + return h.lastRow != nil +} diff --git a/pkg/sql/sem/plpgsqltree/constants.go b/pkg/sql/sem/plpgsqltree/constants.go index 8c8c389406cf..67ebc342990c 100644 --- a/pkg/sql/sem/plpgsqltree/constants.go +++ b/pkg/sql/sem/plpgsqltree/constants.go @@ -79,102 +79,4 @@ func (k GetDiagnosticsKind) String() string { return "SCHEMA_NAME" } panic(errors.AssertionFailedf("unknown diagnostics kind")) - -} - -// FetchDirection represents the direction clause passed into a fetch statement. -type FetchDirection int - -// CursorOption represents a cursor option, which describes how a cursor will -// behave. -type CursorOption uint32 - -const ( - // CursorOptionNone - CursorOptionNone CursorOption = iota - // CursorOptionBinary describes cursors that return data in binary form. - CursorOptionBinary - // CursorOptionScroll describes cursors that can retrieve rows in - // non-sequential fashion. - CursorOptionScroll - // CursorOptionNoScroll describes cursors that can not retrieve rows in - // non-sequential fashion. - CursorOptionNoScroll - // CursorOptionInsensitive describes cursors that can't see changes to - // done to data in same txn. - CursorOptionInsensitive - // CursorOPtionAsensitive describes cursors that may be able to see - // changes to done to data in same txn. - CursorOPtionAsensitive - // CursorOptionHold describes cursors that can be used after a txn that it - // was created in commits. - CursorOptionHold - // CursorOptionFastPlan describes cursors that can not be used after a txn - // that it was created in commits. - CursorOptionFastPlan - // CursorOptionGenericPlan describes cursors that uses a generic plan. - CursorOptionGenericPlan - // CursorOptionCustomPlan describes cursors that uses a custom plan. - CursorOptionCustomPlan - // CursorOptionParallelOK describes cursors that allows parallel queries. - CursorOptionParallelOK -) - -// String implements the fmt.Stringer interface. -func (o CursorOption) String() string { - switch o { - case CursorOptionNoScroll: - return "NO SCROLL" - case CursorOptionScroll: - return "SCROLL" - case CursorOptionFastPlan: - return "" - // TODO(jane): implement string representation for other opts. - default: - return "NOT_IMPLEMENTED_OPT" - } -} - -// Mask returns the bitmask for a given cursor option. -func (o CursorOption) Mask() uint32 { - return 1 << o -} - -// IsSetIn returns true if this cursor option is set in the supplied bitfield. -func (o CursorOption) IsSetIn(bits uint32) bool { - return bits&o.Mask() != 0 -} - -type cursorOptionList []CursorOption - -// ToBitField returns the bitfield representation of a list of cursor options. -func (ol cursorOptionList) ToBitField() uint32 { - var ret uint32 - for _, o := range ol { - ret |= o.Mask() - } - return ret -} - -// OptListFromBitField returns a list of cursor option to be printed. -func OptListFromBitField(m uint32) cursorOptionList { - ret := cursorOptionList{} - opts := []CursorOption{ - CursorOptionBinary, - CursorOptionScroll, - CursorOptionNoScroll, - CursorOptionInsensitive, - CursorOPtionAsensitive, - CursorOptionHold, - CursorOptionFastPlan, - CursorOptionGenericPlan, - CursorOptionCustomPlan, - CursorOptionParallelOK, - } - for _, opt := range opts { - if opt.IsSetIn(m) { - ret = append(ret, opt) - } - } - return ret } diff --git a/pkg/sql/sem/plpgsqltree/statements.go b/pkg/sql/sem/plpgsqltree/statements.go index 68da4883dce2..02dc14c36787 100644 --- a/pkg/sql/sem/plpgsqltree/statements.go +++ b/pkg/sql/sem/plpgsqltree/statements.go @@ -12,6 +12,7 @@ package plpgsqltree import ( "fmt" + "strconv" "strings" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" @@ -57,7 +58,7 @@ func (s *StatementImpl) plpgsqlStmt() {} type Block struct { StatementImpl Label string - Decls []Declaration + Decls []Statement Body []Statement Exceptions []Exception } @@ -134,6 +135,34 @@ func (s *Declaration) WalkStmt(visitor StatementVisitor) { visitor.Visit(s) } +type CursorDeclaration struct { + StatementImpl + Name Variable + Scroll tree.CursorScrollOption + Query tree.Statement +} + +func (s *CursorDeclaration) Format(ctx *tree.FmtCtx) { + ctx.WriteString(string(s.Name)) + switch s.Scroll { + case tree.Scroll: + ctx.WriteString(" SCROLL") + case tree.NoScroll: + ctx.WriteString(" NO SCROLL") + } + ctx.WriteString(" CURSOR FOR ") + s.Query.Format(ctx) + ctx.WriteString(";\n") +} + +func (s *CursorDeclaration) PlpgSQLStatementTag() string { + return "decl_cursor_stmt" +} + +func (s *CursorDeclaration) WalkStmt(visitor StatementVisitor) { + visitor.Visit(s) +} + // stmt_assign type Assignment struct { Statement @@ -850,50 +879,25 @@ func (s *GetDiagnostics) WalkStmt(visitor StatementVisitor) { // stmt_open type Open struct { StatementImpl - CurVar int // TODO(drewk): this could just a Variable - CursorOptions uint32 - // TODO(jane): This is temporary and we should remove it and use CurVar. - CursorName string - WithExplicitExpr bool - // TODO(jane): Should be Expr - ArgQuery string - // TODO(jane): Should be Expr - Query string - // TODO(jane): Should be Expr - DynamicQuery string - // TODO(jane): Should be []Expr - Params []string + CurVar Variable + Scroll tree.CursorScrollOption + Query tree.Statement } func (s *Open) Format(ctx *tree.FmtCtx) { - ctx.WriteString( - fmt.Sprintf( - "OPEN %s ", - s.CursorName, - )) - - opts := OptListFromBitField(s.CursorOptions) - for _, opt := range opts { - if opt.String() != "" { - ctx.WriteString(fmt.Sprintf("%s ", opt.String())) - } + ctx.WriteString("OPEN ") + s.CurVar.Format(ctx) + switch s.Scroll { + case tree.Scroll: + ctx.WriteString(" SCROLL") + case tree.NoScroll: + ctx.WriteString(" NO SCROLL") } - if !s.WithExplicitExpr { - ctx.WriteString("FOR ") - if s.DynamicQuery != "" { - // TODO(drewk): Make sure placeholders are properly printed - ctx.WriteString(fmt.Sprintf("EXECUTE %s ", s.DynamicQuery)) - if len(s.Params) != 0 { - // TODO(drewk): Dont print instances of multiple params with brackets `[...]` - ctx.WriteString(fmt.Sprintf("USING %s", s.Params)) - } - } else { - ctx.WriteString(s.Query) - } - } else { - ctx.WriteString(s.ArgQuery) + if s.Query != nil { + ctx.WriteString(" FOR ") + s.Query.Format(ctx) } - ctx.WriteString("\n") + ctx.WriteString(";\n") } func (s *Open) PlpgSQLStatementTag() string { @@ -908,16 +912,37 @@ func (s *Open) WalkStmt(visitor StatementVisitor) { // stmt_move (where IsMove = true) type Fetch struct { StatementImpl - Target Variable - CurVar int // TODO(drewk): this could just a Variable - Direction FetchDirection - HowMany int64 - Expr Expr - IsMove bool - ReturnsMultiRows bool + Cursor tree.CursorStmt + Target []Variable + IsMove bool } func (s *Fetch) Format(ctx *tree.FmtCtx) { + if s.IsMove { + ctx.WriteString("MOVE ") + } else { + ctx.WriteString("FETCH ") + } + if dir := s.Cursor.FetchType.String(); dir != "" { + ctx.WriteString(dir) + ctx.WriteString(" ") + } + if s.Cursor.FetchType.HasCount() { + ctx.WriteString(strconv.Itoa(int(s.Cursor.Count))) + ctx.WriteString(" ") + } + ctx.WriteString("FROM ") + s.Cursor.Name.Format(ctx) + if s.Target != nil { + ctx.WriteString(" INTO ") + for i := range s.Target { + if i > 0 { + ctx.WriteString(", ") + } + s.Target[i].Format(ctx) + } + } + ctx.WriteString(";\n") } func (s *Fetch) PlpgSQLStatementTag() string { @@ -934,13 +959,13 @@ func (s *Fetch) WalkStmt(visitor StatementVisitor) { // stmt_close type Close struct { StatementImpl - CurVar int // TODO(drewk): this could just a Variable + CurVar Variable } func (s *Close) Format(ctx *tree.FmtCtx) { - // TODO(drewk): Pretty- Print the cursor identifier - ctx.WriteString("CLOSE a cursor\n") - + ctx.WriteString("CLOSE ") + s.CurVar.Format(ctx) + ctx.WriteString(";\n") } func (s *Close) PlpgSQLStatementTag() string { diff --git a/pkg/sql/sem/tree/routine.go b/pkg/sql/sem/tree/routine.go index 83d1a1c6ebb3..2e842cabdff7 100644 --- a/pkg/sql/sem/tree/routine.go +++ b/pkg/sql/sem/tree/routine.go @@ -126,6 +126,10 @@ type RoutineExpr struct { // ExceptionHandler holds the information needed to handle errors if an // exception block was defined. ExceptionHandler *RoutineExceptionHandler + + // CursorDeclaration contains the information needed to open a SQL cursor with + // the result of the *first* body statement. It may be unset. + CursorDeclaration *RoutineOpenCursor } // NewTypedRoutineExpr returns a new RoutineExpr that is well-typed. @@ -141,6 +145,7 @@ func NewTypedRoutineExpr( tailCall bool, procedure bool, exceptionHandler *RoutineExceptionHandler, + cursorDeclaration *RoutineOpenCursor, ) *RoutineExpr { return &RoutineExpr{ Args: args, @@ -154,6 +159,7 @@ func NewTypedRoutineExpr( TailCall: tailCall, Procedure: procedure, ExceptionHandler: exceptionHandler, + CursorDeclaration: cursorDeclaration, } } @@ -191,3 +197,19 @@ type RoutineExceptionHandler struct { // Actions contains a routine to handle each error code. Actions []*RoutineExpr } + +// RoutineOpenCursor stores the information needed to correctly open a cursor +// with the output of a routine. +type RoutineOpenCursor struct { + // NameArgIdx is the index of the routine argument that contains the name of + // the cursor that will be created. + NameArgIdx int + + // Scroll is the scroll option for the cursor, if one was specified. The other + // cursor options are not valid in PLpgSQL. + Scroll CursorScrollOption + + // CursorSQL is a formatted string used to associate the original SQL + // statement with the cursor. + CursorSQL string +} diff --git a/pkg/sql/sql_cursor.go b/pkg/sql/sql_cursor.go index 1a722d6ed03f..5c5166c86aae 100644 --- a/pkg/sql/sql_cursor.go +++ b/pkg/sql/sql_cursor.go @@ -59,12 +59,8 @@ func (p *planner) DeclareCursor(ctx context.Context, s *tree.DeclareCursor) (pla sd.StmtTimeout = 0 } ie := p.ExecCfg().InternalDB.NewInternalExecutor(sd) - if cursor := p.sqlCursors.getCursor(s.Name); cursor != nil { - return nil, pgerror.Newf(pgcode.DuplicateCursor, "cursor %q already exists", s.Name) - } - - if p.extendedEvalCtx.PreparedStatementState.HasPortal(string(s.Name)) { - return nil, pgerror.Newf(pgcode.DuplicateCursor, "cursor %q already exists as portal", s.Name) + if err := p.checkIfCursorExists(s.Name); err != nil { + return nil, err } // Try to plan the cursor query to make sure that it's valid. @@ -122,6 +118,18 @@ func (p *planner) DeclareCursor(ctx context.Context, s *tree.DeclareCursor) (pla }, nil } +// checkIfCursorExists checks whether a cursor or portal with the given name +// already exists, and returns an error if one does. +func (p *planner) checkIfCursorExists(name tree.Name) error { + if cursor := p.sqlCursors.getCursor(name); cursor != nil { + return pgerror.Newf(pgcode.DuplicateCursor, "cursor %q already exists", name) + } + if p.extendedEvalCtx.PreparedStatementState.HasPortal(string(name)) { + return pgerror.Newf(pgcode.DuplicateCursor, "cursor %q already exists as portal", name) + } + return nil +} + var errBackwardScan = pgerror.Newf(pgcode.ObjectNotInPrerequisiteState, "cursor can only scan forward") // FetchCursor implements the FETCH and MOVE statements.