From 6a056b47dbc953a345f0dc6fc619e00e4ee119aa Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Sun, 22 Oct 2023 08:08:02 -0400 Subject: [PATCH 01/12] pg exploration --- Cargo.lock | 331 +++- Cargo.toml | 2 + crates/corro-agent/Cargo.toml | 1 + crates/corro-agent/src/agent.rs | 6 +- crates/corro-agent/src/api/mod.rs | 1 + crates/corro-agent/src/api/pg.rs | 1 + crates/corro-agent/src/api/public/mod.rs | 544 +++++- crates/corro-agent/src/api/public/pubsub.rs | 6 +- crates/corro-pg/Cargo.toml | 26 + crates/corro-pg/src/lib.rs | 260 +++ crates/corro-pg/src/proto.rs | 650 ++++++++ crates/corro-pg/src/proto_ext.rs | 163 ++ crates/corro-pg/src/sql_state.rs | 1668 +++++++++++++++++++ crates/corro-types/Cargo.toml | 3 + crates/corro-types/src/agent.rs | 27 +- crates/corro-types/src/config.rs | 9 + crates/corro-types/src/http.rs | 181 ++ crates/corro-types/src/lib.rs | 1 + crates/corro-types/src/schema.rs | 12 +- 19 files changed, 3856 insertions(+), 36 deletions(-) create mode 100644 crates/corro-agent/src/api/pg.rs create mode 100644 crates/corro-pg/Cargo.toml create mode 100644 crates/corro-pg/src/lib.rs create mode 100644 crates/corro-pg/src/proto.rs create mode 100644 crates/corro-pg/src/proto_ext.rs create mode 100644 crates/corro-pg/src/sql_state.rs create mode 100644 crates/corro-types/src/http.rs diff --git a/Cargo.lock b/Cargo.lock index 57fe977d..e13e2df7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -50,6 +50,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + [[package]] name = "android_system_properties" version = "0.1.5" @@ -319,6 +325,22 @@ version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a" +[[package]] +name = "base64ct" +version = "1.6.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + +[[package]] +name = "bcder" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf16bec990f8ea25cab661199904ef452fcf11f565c404ce6cffbdf3f8cbbc47" +dependencies = [ + "bytes", + "smallvec", +] + [[package]] name = "beef" version = "0.5.2" @@ -503,15 +525,17 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.24" +version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e3c5919066adf22df73762e50cffcde3a758f2a848b113b586d1f86728b673b" +checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" dependencies = [ + "android-tzdata", "iana-time-zone", - "num-integer", + "js-sys", "num-traits", "serde", - "winapi", + "wasm-bindgen", + "windows-targets 0.48.0", ] [[package]] @@ -617,6 +641,12 @@ dependencies = [ "toml", ] +[[package]] +name = "const-oid" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28c122c3980598d243d63d9a704629a2d748d101f278052ff068be5a4423ab6f" + [[package]] name = "const-random" version = "0.1.15" @@ -728,6 +758,7 @@ dependencies = [ "quoted-string", "rand", "rangemap", + "rmp-serde", "rusqlite", "rustls", "rustls-pemfile", @@ -790,6 +821,30 @@ dependencies = [ "uuid", ] +[[package]] +name = "corro-pg" +version = "0.1.0" +dependencies = [ + "bytes", + "compact_str 0.7.0", + "corro-tests", + "corro-types", + "futures", + "pgwire", + "phf", + "postgres-types", + "rusqlite", + "sqlparser", + "tempfile", + "thiserror", + "time", + "tokio", + "tokio-postgres", + "tokio-util", + "tracing", + "tracing-subscriber", +] + [[package]] name = "corro-speedy" version = "0.8.7" @@ -862,15 +917,18 @@ dependencies = [ "fallible-iterator", "foca", "futures", + "hyper", "indexmap", "itertools", "metrics", 
"once_cell", "opentelemetry", "parking_lot", + "pin-project-lite", "rand", "rangemap", "rcgen", + "rmp-serde", "rusqlite", "serde", "serde_json", @@ -1166,6 +1224,16 @@ dependencies = [ "thiserror", ] +[[package]] +name = "der" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" +dependencies = [ + "const-oid", + "zeroize", +] + [[package]] name = "der-parser" version = "8.2.0" @@ -1180,6 +1248,17 @@ dependencies = [ "rusticata-macros", ] +[[package]] +name = "derive-new" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3418329ca0ad70234b9735dc4ceed10af4df60eff9c8e7b06cb5e520d92c3535" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "derive_more" version = "0.99.17" @@ -1213,6 +1292,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -1361,6 +1441,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "finl_unicode" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6" + [[package]] name = "fnv" version = "1.0.7" @@ -1509,6 +1595,18 @@ dependencies = [ "wasi", ] +[[package]] +name = "getset" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e45727250e75cc04ff2846a66397da8ef2b3db8e40e0cef4df67950a07621eb9" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "gimli" version = "0.27.2" @@ -1628,6 +1726,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hmac" +version = "0.12.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "hostname" version = "0.3.1" @@ -2089,6 +2196,22 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40" +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] + +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.5.0" @@ -2495,6 +2618,12 @@ dependencies = [ "windows-sys 0.45.0", ] +[[package]] +name = "paste" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + [[package]] name = "pathdiff" version = "0.2.1" @@ -2517,6 +2646,34 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" +[[package]] +name = "pgwire" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06d04982366efd653d4365175426acbabd55efb07231869e92b9e1f5b3faf7df" +dependencies = [ + "async-trait", + "base64 0.21.0", + "bytes", + "chrono", + "derive-new", + "futures", + "getset", + "hex", + "log", + "md5", + "postgres-types", + "rand", + "ring", + "stringprep", + "thiserror", + "time", + "tokio", + "tokio-rustls", + "tokio-util", + "x509-certificate", +] + [[package]] name = "phf" version = "0.11.1" @@ -2600,6 +2757,37 @@ version = "1.2.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "f602a0d1e09a48e4f8e8b4d4042e32807c3676da31f2ecabeac9f96226ec6c45" +[[package]] +name = "postgres-protocol" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49b6c5ef183cd3ab4ba005f1ca64c21e8bd97ce4699cfea9e8d9a2c4958ca520" +dependencies = [ + "base64 0.21.0", + "byteorder", + "bytes", + "fallible-iterator", + "hmac", + "md-5", + "memchr", + "rand", + "sha2", + "stringprep", +] + +[[package]] +name = "postgres-types" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d2234cdee9408b523530a9b6d2d6b373d1db34f6a8e51dc03ded1828d7fb67c" +dependencies = [ + "bytes", + "chrono", + "fallible-iterator", + "postgres-protocol", + "time", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -2975,6 +3163,28 @@ dependencies = [ "winapi", ] +[[package]] +name = "rmp" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f9860a6cc38ed1da53456442089b4dfa35e7cedaa326df63017af88385e6b20" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmp-serde" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bffea85eea980d8a74453e5d02a8d93028f3c34725de143085a844ebe953258a" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + [[package]] name = "rusqlite" version = "0.29.0" @@ -3260,6 +3470,17 @@ dependencies = [ "digest", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.4" @@ -3288,6 +3509,12 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "5e1788eed21689f9cf370582dfc467ef36ed9c707f073528ddafa8d83e3b8500" + [[package]] name = "siphasher" version = "0.3.10" @@ -3388,6 +3615,16 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "spki" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1e996ef02c474957d681f1b05213dfb0abab947b446a62d37770b23500184a" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "sqlite-pool" version = "0.1.0" @@ -3427,12 +3664,32 @@ dependencies = [ "tracing", ] +[[package]] +name = "sqlparser" +version = "0.38.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0272b7bb0a225320170c99901b4b5fb3a4384e255a7f2cc228f61e2ba3893e75" +dependencies = [ + "log", +] + [[package]] name = "static_assertions" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "stringprep" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb41d74e231a107a1b4ee36bd1214b11285b77768d2e3824aedafa988fd36ee6" +dependencies = [ + "finl_unicode", + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "strsim" version = "0.10.0" @@ -3461,6 +3718,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "subtle" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" + [[package]] name = "supports-color" version = "1.3.1" @@ -3731,6 +3994,32 @@ dependencies = [ "tokio-stream", ] +[[package]] +name = "tokio-postgres" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d340244b32d920260ae7448cb72b6e238bddc3d4f7603394e7dd46ed8e48f5b8" 
+dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures-channel", + "futures-util", + "log", + "parking_lot", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol", + "postgres-types", + "rand", + "socket2 0.5.3", + "tokio", + "tokio-util", + "whoami", +] + [[package]] name = "tokio-rustls" version = "0.24.0" @@ -4321,6 +4610,16 @@ dependencies = [ "untrusted", ] +[[package]] +name = "whoami" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22fc3756b8a9133049b26c7f61ab35416c130e8c09b660f5b3958b446f52cc50" +dependencies = [ + "wasm-bindgen", + "web-sys", +] + [[package]] name = "widestring" version = "0.5.1" @@ -4523,6 +4822,24 @@ dependencies = [ "winapi", ] +[[package]] +name = "x509-certificate" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5d27c90840e84503cf44364de338794d5d5680bdd1da6272d13f80b0769ee0" +dependencies = [ + "bcder", + "bytes", + "chrono", + "der", + "hex", + "pem", + "ring", + "signature", + "spki", + "thiserror", +] + [[package]] name = "x509-parser" version = "0.15.0" @@ -4564,3 +4881,9 @@ checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" dependencies = [ "time", ] + +[[package]] +name = "zeroize" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" diff --git a/Cargo.toml b/Cargo.toml index 108166af..33ae26e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ hyper = { version = "0.14.26", features = ["h2", "http1", "http2", "server", "tc hyper-rustls = { version = "0.24.0", features = ["http2"] } indexmap = { version = "1.9.3", features = ["serde"] } itertools = { version = "0.10.5" } +rmp-serde = { version = "1.1.2" } metrics = "0.21.0" metrics-exporter-prometheus = "0.12.0" once_cell = "1.17.1" @@ -41,6 +42,7 @@ opentelemetry-otlp = { 
version = "0.13.0" } opentelemetry-semantic-conventions = { version = "0.12.0" } parking_lot = { version = "0.12.1" } pin-project-lite = "0.2.9" +polonius-the-crab = { version = "0.4.1" } quinn = "0.10.2" quinn-proto = "0.10.5" quinn-plaintext = "0.1.0" diff --git a/crates/corro-agent/Cargo.toml b/crates/corro-agent/Cargo.toml index 1e6e2561..c4c4db20 100644 --- a/crates/corro-agent/Cargo.toml +++ b/crates/corro-agent/Cargo.toml @@ -28,6 +28,7 @@ quinn-plaintext = { workspace = true } quoted-string = { workspace = true } rand = { workspace = true } rangemap = { workspace = true } +rmp-serde = { workspace = true } rusqlite = { workspace = true } rustls = { workspace = true } rustls-pemfile = "*" diff --git a/crates/corro-agent/src/agent.rs b/crates/corro-agent/src/agent.rs index 5042e328..0abfcebb 100644 --- a/crates/corro-agent/src/agent.rs +++ b/crates/corro-agent/src/agent.rs @@ -345,7 +345,7 @@ pub async fn run(agent: Agent, opts: AgentOptions) -> eyre::Result<()> { }; for (id, sql) in rows { - let conn = agent.pool().dedicated().await?; + let conn = block_in_place(|| agent.pool().dedicated())?; let (evt_tx, evt_rx) = channel(512); match Matcher::restore(id, &agent.schema().read(), conn, evt_tx, &sql) { Ok(handle) => { @@ -2321,7 +2321,7 @@ async fn handle_sync(agent: &Agent, transport: &Transport) -> Result<(), SyncCli }; if candidates.is_empty() { - return Err(SyncClientError::NoGoodCandidate); + return Ok(()); } debug!("found {} candidates to synchronize with", candidates.len()); @@ -2351,7 +2351,7 @@ async fn handle_sync(agent: &Agent, transport: &Transport) -> Result<(), SyncCli }; if chosen.is_empty() { - return Err(SyncClientError::NoGoodCandidate); + return Ok(()); } let start = Instant::now(); diff --git a/crates/corro-agent/src/api/mod.rs b/crates/corro-agent/src/api/mod.rs index dbabc9f7..cffa154c 100644 --- a/crates/corro-agent/src/api/mod.rs +++ b/crates/corro-agent/src/api/mod.rs @@ -1,2 +1,3 @@ pub mod peer; +pub mod pg; pub mod public; diff --git 
a/crates/corro-agent/src/api/pg.rs b/crates/corro-agent/src/api/pg.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/crates/corro-agent/src/api/pg.rs @@ -0,0 +1 @@ + diff --git a/crates/corro-agent/src/api/public/mod.rs b/crates/corro-agent/src/api/public/mod.rs index f4e1df1a..b1983a30 100644 --- a/crates/corro-agent/src/api/public/mod.rs +++ b/crates/corro-agent/src/api/public/mod.rs @@ -1,33 +1,43 @@ use std::{ + collections::HashMap, iter::Peekable, - ops::RangeInclusive, + mem::forget, + ops::{Deref, DerefMut, RangeInclusive}, time::{Duration, Instant}, }; -use axum::{response::IntoResponse, Extension}; -use bytes::{BufMut, BytesMut}; +use axum::{extract, response::IntoResponse, Extension}; +use bytes::{BufMut, Bytes, BytesMut}; use compact_str::ToCompactString; use corro_types::{ agent::{Agent, ChangeError, KnownDbVersion}, api::{row_to_change, ExecResponse, ExecResult, QueryEvent, Statement}, broadcast::{ChangeV1, Changeset, Timestamp}, change::SqliteValue, + http::{IoBodyStream, LinesBytesCodec}, schema::{apply_schema, parse_sql}, sqlite::SqlitePoolError, }; +use futures::StreamExt; use hyper::StatusCode; use itertools::Itertools; use metrics::counter; -use rusqlite::{named_params, params_from_iter, ToSql, Transaction}; +use rusqlite::{named_params, params_from_iter, Connection, ToSql, Transaction}; +use serde::{Deserialize, Serialize}; use spawn::spawn_counted; use tokio::{ sync::{ - mpsc::{self, channel}, + mpsc::{self, channel, error::SendError, Receiver, Sender}, oneshot, }, task::block_in_place, }; -use tracing::{debug, error, info, trace}; +use tokio_util::{ + codec::{Encoder, FramedRead, LengthDelimitedCodec}, + io::StreamReader, + sync::CancellationToken, +}; +use tracing::{debug, error, info, trace, Instrument}; use corro_types::{ broadcast::{BroadcastInput, BroadcastV1}, @@ -146,6 +156,489 @@ where } } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +#[repr(u8)] +pub enum Stmt { + 
Prepare(String), + Drop(u32), + Reset(u32), + Columns(u32), + + Execute(u32, Vec), + Query(u32, Vec), + + Next(u32), + + Begin, + Commit, + Rollback, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +#[repr(u8)] +pub enum SqliteResult { + Ok, + Error(String), + + Statement { + id: u32, + params_count: usize, + }, + + Execute { + rows_affected: usize, + last_insert_rowid: i64, + }, + + Columns(Vec), + + // None represents the end of a statement's rows + Row(Option>), +} + +#[derive(Debug, thiserror::Error)] +enum HandleConnError { + #[error(transparent)] + Rusqlite(#[from] rusqlite::Error), + #[error("events channel closed")] + EventsChannelClosed, +} + +impl From> for HandleConnError { + fn from(value: SendError) -> Self { + HandleConnError::EventsChannelClosed + } +} + +#[derive(Clone, Debug)] +struct IncrMap { + map: HashMap, + last: u32, +} + +impl IncrMap { + pub fn insert(&mut self, v: V) -> u32 { + self.last += 1; + self.map.insert(self.last, v); + self.last + } +} + +impl Default for IncrMap { + fn default() -> Self { + Self { + map: Default::default(), + last: Default::default(), + } + } +} + +impl Deref for IncrMap { + type Target = HashMap; + + fn deref(&self) -> &Self::Target { + &self.map + } +} + +impl DerefMut for IncrMap { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.map + } +} + +#[derive(Debug, Default)] +struct ConnCache<'conn> { + prepared: IncrMap>, + cells: Vec, +} + +#[derive(Debug, thiserror::Error)] +enum StmtError { + #[error(transparent)] + Rusqlite(#[from] rusqlite::Error), + #[error("statement not found: {id}")] + StatementNotFound { id: u32 }, +} + +fn handle_stmt<'conn>( + agent: &Agent, + conn: &'conn Connection, + cache: &mut ConnCache<'conn>, + stmt: Stmt, +) -> Result { + match stmt { + Stmt::Prepare(sql) => { + let prepped = conn.prepare(&sql)?; + let params_count = prepped.parameter_count(); + let id = cache.prepared.insert(prepped); + Ok(SqliteResult::Statement 
{ id, params_count }) + } + Stmt::Columns(id) => { + let prepped = cache + .prepared + .get(&id) + .ok_or(StmtError::StatementNotFound { id })?; + Ok(SqliteResult::Columns( + prepped + .column_names() + .into_iter() + .map(|name| name.to_string()) + .collect(), + )) + } + Stmt::Execute(id, params) => { + let prepped = cache + .prepared + .get_mut(&id) + .ok_or(StmtError::StatementNotFound { id })?; + let rows_affected = prepped.execute(params_from_iter(params))?; + Ok(SqliteResult::Execute { + rows_affected, + last_insert_rowid: conn.last_insert_rowid(), + }) + } + Stmt::Query(id, params) => { + let prepped = cache + .prepared + .get_mut(&id) + .ok_or(StmtError::StatementNotFound { id })?; + + for (i, param) in params.into_iter().enumerate() { + prepped.raw_bind_parameter(i + 1, param)?; + } + + Ok(SqliteResult::Ok) + } + Stmt::Next(id) => { + let prepped = cache + .prepared + .get_mut(&id) + .ok_or(StmtError::StatementNotFound { id })?; + + // creates an interator for already-bound statements + let mut rows = prepped.raw_query(); + + let res = match rows.next()? { + Some(row) => { + let col_count = row.as_ref().column_count(); + cache.cells.clear(); + for idx in 0..col_count { + let v = row.get::<_, SqliteValue>(idx)?; + cache.cells.push(v); + } + Ok(SqliteResult::Row(Some(cache.cells.drain(..).collect_vec()))) + } + None => Ok(SqliteResult::Row(None)), + }; + + // prevent running Drop so it doesn't reset everything... + forget(rows); + + res + } + Stmt::Begin => { + conn.execute_batch("BEGIN")?; + Ok(SqliteResult::Ok) + } + Stmt::Commit => { + handle_commit(agent, conn)?; + Ok(SqliteResult::Ok) + } + Stmt::Rollback => { + conn.execute_batch("ROLLBACK")?; + Ok(SqliteResult::Ok) + } + Stmt::Drop(id) => { + cache.prepared.remove(&id); + Ok(SqliteResult::Ok) + } + Stmt::Reset(id) => { + let prepped = cache + .prepared + .get_mut(&id) + .ok_or(StmtError::StatementNotFound { id })?; + + // not sure how to reset a statement otherwise.. 
+ let rows = prepped.raw_query(); + drop(rows); + + Ok(SqliteResult::Ok) + } + } +} + +fn handle_interactive( + agent: &Agent, + mut queries: Receiver, + events: Sender, +) -> Result<(), HandleConnError> { + let conn = match agent.pool().client_dedicated() { + Ok(conn) => conn, + Err(e) => { + return events + .blocking_send(SqliteResult::Error(e.to_string())) + .map_err(HandleConnError::from); + } + }; + + let mut cache = ConnCache::default(); + + while let Some(stmt) = queries.blocking_recv() { + match handle_stmt(agent, &conn, &mut cache, stmt) { + Ok(res) => { + events.blocking_send(res)?; + } + Err(e) => { + events.blocking_send(SqliteResult::Error(e.to_string()))?; + } + } + } + + Ok(()) +} + +fn handle_commit(agent: &Agent, conn: &Connection) -> rusqlite::Result<()> { + let actor_id = agent.actor_id(); + + let ts = Timestamp::from(agent.clock().new_timestamp()); + + let db_version: i64 = conn + .prepare_cached("SELECT crsql_next_db_version()")? + .query_row((), |row| row.get(0))?; + + let has_changes: bool = conn + .prepare_cached( + "SELECT EXISTS(SELECT 1 FROM crsql_changes WHERE site_id IS NULL AND db_version = ?);", + )? + .query_row([db_version], |row| row.get(0))?; + + if !has_changes { + conn.execute_batch("COMMIT")?; + return Ok(()); + } + + let booked = { + agent + .bookie() + .blocking_write("handle_write_tx(for_actor)") + .for_actor(actor_id) + }; + + let last_seq: i64 = conn + .prepare_cached( + "SELECT MAX(seq) FROM crsql_changes WHERE site_id IS NULL AND db_version = ?", + )? 
+ .query_row([db_version], |row| row.get(0))?; + + let mut book_writer = booked.blocking_write("handle_write_tx(book_writer)"); + + let last_version = book_writer.last().unwrap_or_default(); + trace!("last_version: {last_version}"); + let version = last_version + 1; + trace!("version: {version}"); + + conn.prepare_cached( + r#" + INSERT INTO __corro_bookkeeping (actor_id, start_version, db_version, last_seq, ts) + VALUES (:actor_id, :start_version, :db_version, :last_seq, :ts); + "#, + )? + .execute(named_params! { + ":actor_id": actor_id, + ":start_version": version, + ":db_version": db_version, + ":last_seq": last_seq, + ":ts": ts + })?; + + debug!(%actor_id, %version, %db_version, "inserted local bookkeeping row!"); + + conn.execute_batch("COMMIT")?; + + trace!("committed tx, db_version: {db_version}, last_seq: {last_seq:?}"); + + book_writer.insert( + version, + KnownDbVersion::Current { + db_version, + last_seq, + ts, + }, + ); + + let agent = agent.clone(); + + spawn_counted(async move { + let conn = agent.pool().read().await?; + + block_in_place(|| { + // TODO: make this more generic so both sync and local changes can use it. + let mut prepped = conn.prepare_cached(r#" + SELECT "table", pk, cid, val, col_version, db_version, seq, COALESCE(site_id, crsql_site_id()), cl + FROM crsql_changes + WHERE site_id IS NULL + AND db_version = ? 
+ ORDER BY seq ASC + "#)?; + let rows = prepped.query_map([db_version], row_to_change)?; + let chunked = ChunkedChanges::new(rows, 0, last_seq, MAX_CHANGES_BYTE_SIZE); + for changes_seqs in chunked { + match changes_seqs { + Ok((changes, seqs)) => { + for (table_name, count) in changes.iter().counts_by(|change| &change.table) + { + counter!("corro.changes.committed", count as u64, "table" => table_name.to_string(), "source" => "local"); + } + process_subs(&agent, &changes); + + trace!("broadcasting changes: {changes:?} for seq: {seqs:?}"); + + let tx_bcast = agent.tx_bcast().clone(); + tokio::spawn(async move { + if let Err(e) = tx_bcast + .send(BroadcastInput::AddBroadcast(BroadcastV1::Change( + ChangeV1 { + actor_id, + changeset: Changeset::Full { + version, + changes, + seqs, + last_seq, + ts, + }, + }, + ))) + .await + { + error!("could not send change message for broadcast: {e}"); + } + }); + } + Err(e) => { + error!("could not process crsql change (db_version: {db_version}) for broadcast: {e}"); + break; + } + } + } + Ok::<_, rusqlite::Error>(()) + })?; + Ok::<_, eyre::Report>(()) + }); + Ok::<_, rusqlite::Error>(()) +} + +#[tracing::instrument(skip_all)] +pub async fn api_v1_begins( + // axum::extract::RawQuery(raw_query): axum::extract::RawQuery, + Extension(agent): Extension, + req_body: extract::RawBody, +) -> impl IntoResponse { + let (mut body_tx, body) = hyper::Body::channel(); + + let req_body = IoBodyStream { body: req_body.0 }; + + let (queries_tx, queries_rx) = channel(512); + let (events_tx, mut events_rx) = channel(512); + let cancel = CancellationToken::new(); + + tokio::spawn({ + let cancel = cancel.clone(); + let events_tx = events_tx.clone(); + async move { + let _drop_guard = cancel.drop_guard(); + + let mut req_reader = + FramedRead::new(StreamReader::new(req_body), LengthDelimitedCodec::default()); + + while let Some(buf_res) = req_reader.next().await { + match buf_res { + Ok(buf) => match rmp_serde::from_slice(&buf) { + Ok(req) => { + if 
let Err(e) = queries_tx.send(req).await { + error!("could not send request into channel: {e}"); + if let Err(e) = events_tx + .send(SqliteResult::Error("request channel closed".into())) + .await + { + error!("could not send error event: {e}"); + } + return; + } + } + Err(e) => { + error!("could not parse message: {e}"); + if let Err(e) = events_tx + .send(SqliteResult::Error("request channel closed".into())) + .await + { + error!("could not send error event: {e}"); + } + } + }, + Err(e) => { + error!("could not read buffer from request body: {e}"); + break; + } + } + } + } + .in_current_span() + }); + + // probably a better way to do this... + spawn_counted( + async move { block_in_place(|| handle_interactive(&agent, queries_rx, events_tx)) } + .in_current_span(), + ); + + tokio::spawn(async move { + let mut ser_buf = BytesMut::new(); + let mut encode_buf = BytesMut::new(); + let mut codec = LengthDelimitedCodec::default(); + + while let Some(event) = events_rx.recv().await { + match rmp_serde::encode::write(&mut (&mut ser_buf).writer(), &event) { + Ok(_) => match codec.encode(ser_buf.split().freeze(), &mut encode_buf) { + Ok(_) => { + if let Err(e) = body_tx.send_data(encode_buf.split().freeze()).await { + error!("could not send tx event to response body: {e}"); + return; + } + } + Err(e) => { + error!("could not encode event: {e}"); + if let Err(e) = body_tx + .send_data(Bytes::from(r#"{"error": "could not encode event"}"#)) + .await + { + error!("could not send encoding error to body: {e}"); + return; + } + } + }, + Err(e) => { + error!("could not serialize event: {e}"); + if let Err(e) = body_tx + .send_data(Bytes::from(r#"{"error": "could not serialize event"}"#)) + .await + { + error!("could not send serialize error to body: {e}"); + return; + } + } + } + } + }); + + hyper::Response::builder() + .status(hyper::StatusCode::SWITCHING_PROTOCOLS) + .body(body) + .unwrap() +} + const MAX_CHANGES_BYTE_SIZE: usize = 8 * 1024; pub async fn make_broadcastable_changes( 
@@ -710,9 +1203,11 @@ pub async fn api_v1_db_schema( #[cfg(test)] mod tests { use bytes::Bytes; + use corro_tests::launch_test_agent; use corro_types::{api::RowId, config::Config, schema::SqliteType}; use futures::Stream; use http_body::{combinators::UnsyncBoxBody, Body}; + use spawn::wait_for_all_pending_handles; use tokio::sync::mpsc::error::TryRecvError; use tokio_util::codec::{Decoder, LinesCodec}; use tripwire::Tripwire; @@ -1244,4 +1739,41 @@ mod tests { assert_eq!(chunker.next(), None); } + + #[tokio::test(flavor = "multi_thread")] + async fn test_interactive() -> eyre::Result<()> { + _ = tracing_subscriber::fmt::try_init(); + let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple(); + let ta = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?; + + let (q_tx, q_rx) = channel(1); + let (e_tx, mut e_rx) = channel(1); + + spawn_counted(async move { block_in_place(|| handle_interactive(&ta.agent, q_rx, e_tx)) }); + + q_tx.send(Stmt::Prepare("SELECT 123".into())).await?; + let e = e_rx.recv().await.unwrap(); + println!("e: {e:?}"); + + q_tx.send(Stmt::Query(1, vec![])).await?; + let e = e_rx.recv().await.unwrap(); + println!("e: {e:?}"); + + q_tx.send(Stmt::Next(1)).await?; + let e = e_rx.recv().await.unwrap(); + assert_eq!(e, SqliteResult::Row(Some(vec![SqliteValue::Integer(123)]))); + + q_tx.send(Stmt::Next(1)).await?; + let e = e_rx.recv().await.unwrap(); + assert_eq!(e, SqliteResult::Row(None)); + + drop(q_tx); + drop(e_rx); + + tripwire_tx.send(()).await.ok(); + tripwire_worker.await; + wait_for_all_pending_handles().await; + + Ok(()) + } } diff --git a/crates/corro-agent/src/api/public/pubsub.rs b/crates/corro-agent/src/api/public/pubsub.rs index 5c523f9c..dfc8cd45 100644 --- a/crates/corro-agent/src/api/public/pubsub.rs +++ b/crates/corro-agent/src/api/public/pubsub.rs @@ -186,7 +186,7 @@ pub async fn process_sub_channel( }; // get a dedicated connection - let conn = match agent.pool().dedicated().await { + let conn = match 
agent.pool().dedicated() { Ok(conn) => conn, Err(e) => { error!("could not acquire dedicated connection for subscription cleanup: {e}"); @@ -428,7 +428,7 @@ pub async fn catch_up_sub( let last_query_event = { let mut buf = BytesMut::new(); - let mut conn = match agent.pool().dedicated().await { + let mut conn = match agent.pool().dedicated() { Ok(conn) => conn, Err(e) => { evt_tx.send(error_to_query_event_bytes(&mut buf, e)).await?; @@ -532,7 +532,7 @@ pub async fn upsert_sub( return Err(MatcherUpsertError::SubFromWithoutMatcher); } - let conn = agent.pool().dedicated().await?; + let conn = agent.pool().dedicated()?; let (evt_tx, evt_rx) = mpsc::channel(512); diff --git a/crates/corro-pg/Cargo.toml b/crates/corro-pg/Cargo.toml new file mode 100644 index 00000000..528be5f8 --- /dev/null +++ b/crates/corro-pg/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "corro-pg" +version = "0.1.0" +edition = "2021" + +[dependencies] +bytes = { workspace = true } +compact_str = { workspace = true } +corro-types = { path = "../corro-types" } +futures = { workspace = true } +rusqlite = { workspace = true } +sqlparser = { version = "0.38" } +pgwire = { version = "0.16.1" } +thiserror = { workspace = true } +tokio = { workspace = true } +tokio-util = { workspace = true } +tracing = { workspace = true } +time = { workspace = true } +phf = "*" +postgres-types = { version = "0.2", features = ["with-time-0_3"] } + +[dev-dependencies] +tracing-subscriber = { workspace = true } +tempfile = { workspace = true } +corro-tests = { path = "../corro-tests" } +tokio-postgres = { version = "0.7.10" } \ No newline at end of file diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs new file mode 100644 index 00000000..b373c451 --- /dev/null +++ b/crates/corro-pg/src/lib.rs @@ -0,0 +1,260 @@ +pub mod proto; +pub mod proto_ext; +pub mod sql_state; + +use std::{collections::HashMap, net::SocketAddr}; + +use compact_str::CompactString; +use corro_types::{agent::Agent, change::SqliteValue, 
config::PgConfig}; +use futures::{Sink, SinkExt, StreamExt}; +use pgwire::{ + api::ClientInfoHolder, + error::{ErrorInfo, PgWireError}, + messages::{ + response::{ReadyForQuery, READY_STATUS_IDLE}, + PgWireBackendMessage, PgWireFrontendMessage, + }, + tokio::PgWireMessageServerCodec, +}; +use rusqlite::Statement; +use tokio::net::TcpListener; +use tokio_util::codec::{Framed, FramedRead}; +use tracing::{debug, info}; + +use crate::{ + proto::{Bind, ConnectionCodec, ProtocolError}, + sql_state::SqlState, +}; + +type BoxError = Box; + +struct PgServer { + local_addr: SocketAddr, +} + +async fn start(agent: Agent, pg: PgConfig) -> Result { + let mut server = TcpListener::bind(pg.bind_addr).await?; + let local_addr = server.local_addr()?; + + tokio::spawn(async move { + loop { + let (conn, remote_addr) = server.accept().await?; + info!("accepted a conn, addr: {remote_addr}"); + + let agent = agent.clone(); + tokio::spawn(async move { + let mut framed = Framed::new( + conn, + PgWireMessageServerCodec::new(ClientInfoHolder::new(remote_addr, false)), + ); + + let msg = framed.next().await.unwrap()?; + + match msg { + PgWireFrontendMessage::Startup(startup) => { + info!("received startup message: {startup:?}"); + } + _ => { + framed + .send(PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".into(), + SqlState::PROTOCOL_VIOLATION.code().into(), + "expected startup message".into(), + ) + .into(), + )) + .await?; + return Ok(()); + } + } + + framed + .send(PgWireBackendMessage::Authentication( + pgwire::messages::startup::Authentication::Ok, + )) + .await?; + framed + .send(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new( + READY_STATUS_IDLE, + ))) + .await?; + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + std::thread::spawn(move || -> Result<(), BoxError> { + rt.block_on(async move { + let conn = rusqlite::Connection::open_in_memory().unwrap(); + + let mut prepared: HashMap)> = HashMap::new(); + + let 
mut portals: HashMap = + HashMap::new(); + + let mut row_cache: Vec = vec![]; + + while let Some(decode_res) = framed.next().await { + let msg = match decode_res { + Ok(msg) => msg, + Err(PgWireError::IoError(io_error)) => { + debug!("postgres io error: {io_error}"); + break; + } + // Err(ProtocolError::ParserError) => { + // framed + // .send(proto::ErrorResponse::new( + // proto::SqlState::SyntaxError, + // proto::Severity::Error, + // "parsing error", + // )) + // .await?; + // continue; + // } + Err(e) => { + framed + .send(PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + )) + .await?; + break; + } + }; + + match msg { + PgWireFrontendMessage::Startup(_) => { + framed + .send(PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".into(), + SqlState::PROTOCOL_VIOLATION.code().into(), + "unexpected startup message".into(), + ) + .into(), + )) + .await?; + continue; + } + PgWireFrontendMessage::Parse(parse) => { + if let Err(e) = conn.prepare_cached(parse.query()) { + framed + .send(PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + )) + .await?; + continue; + } + + prepared.insert( + parse.name().as_deref().unwrap_or("").into(), + (parse.query().clone(), parse.type_oids().clone()), + ); + } + PgWireFrontendMessage::Describe(_) => todo!(), + PgWireFrontendMessage::Bind(bind) => { + let portal_name = bind + .portal_name() + .as_deref() + .map(CompactString::from) + .unwrap_or_default(); + + let stmt_name = bind.statement_name().as_deref().unwrap_or(""); + } + PgWireFrontendMessage::Sync(_) => todo!(), + PgWireFrontendMessage::Execute(_) => todo!(), + PgWireFrontendMessage::Query(_) => todo!(), + PgWireFrontendMessage::Terminate(_) => todo!(), + + PgWireFrontendMessage::PasswordMessageFamily(_) => todo!(), + PgWireFrontendMessage::Close(_) => todo!(), + PgWireFrontendMessage::Flush(_) => 
todo!(), + PgWireFrontendMessage::CopyData(_) => todo!(), + PgWireFrontendMessage::CopyFail(_) => todo!(), + PgWireFrontendMessage::CopyDone(_) => todo!(), + } + } + + Ok::<_, BoxError>(()) + }) + }) + .join() + .unwrap()?; + + Ok::<_, BoxError>(()) + }); + } + + info!("postgres server done"); + + Ok::<_, BoxError>(()) + }); + + return Ok(PgServer { local_addr }); +} + +// #[cfg(test)] +// mod tests { +// use tokio_postgres::NoTls; + +// use super::*; + +// #[tokio::test] +// async fn test_pg() -> Result<(), BoxError> { +// _ = tracing_subscriber::fmt::try_init(); +// let server = TcpListener::bind("127.0.0.1:0").await?; +// let local_addr = server.local_addr()?; + +// let conn_str = format!( +// "host={} port={} user=testuser", +// local_addr.ip(), +// local_addr.port() +// ); + +// let client_task = tokio::spawn(async move { +// let (client, client_conn) = tokio_postgres::connect(&conn_str, NoTls).await?; +// println!("client is ready!"); +// tokio::spawn(client_conn); + +// client.prepare("SELECT 1").await?; +// Ok::<_, BoxError>(()) +// }); + +// let (conn, remote_addr) = server.accept().await?; +// println!("accepted a conn, addr: {remote_addr}"); + +// let mut framed = Framed::new(conn, ConnectionCodec::new()); + +// let msg = framed.next().await.unwrap()?; +// println!("recv msg: {msg:?}"); + +// framed.send(proto::AuthenticationOk).await?; +// framed.send(proto::ReadyForQuery).await?; + +// let msg = framed.next().await.unwrap()?; +// println!("recv msg: {msg:?}"); + +// let query = if let PgWireFrontendMessage::Parse(Parse { query, .. }) = msg { +// query +// } else { +// panic!("unexpected message"); +// }; + +// println!("query: {query}"); + +// assert!(client_task.await?.is_ok()); + +// Ok(()) +// } +// } diff --git a/crates/corro-pg/src/proto.rs b/crates/corro-pg/src/proto.rs new file mode 100644 index 00000000..29cfa126 --- /dev/null +++ b/crates/corro-pg/src/proto.rs @@ -0,0 +1,650 @@ +//! 
Contains types that represent the core Postgres wire protocol. + +// this module requires a lot more work to document +// may want to build this automatically from Postgres docs if possible +#![allow(missing_docs)] + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::convert::TryFrom; +use std::fmt::Display; +use std::mem::size_of; +use std::{collections::HashMap, convert::TryInto}; +use tokio_util::codec::{Decoder, Encoder}; + +macro_rules! data_types { + ($($name:ident = $oid:expr, $size: expr)*) => { + #[derive(Debug, Copy, Clone)] + /// Describes a Postgres data type. + pub enum DataTypeOid { + $( + #[allow(missing_docs)] + $name, + )* + /// A type which is not known to this crate. + Unknown(u32), + } + + impl DataTypeOid { + /// Fetch the size in bytes for this data type. + /// Variably-sized types return -1. + pub fn size_bytes(&self) -> i16 { + match self { + $( + Self::$name => $size, + )* + Self::Unknown(_) => unimplemented!(), + } + } + } + + impl From for DataTypeOid { + fn from(value: u32) -> Self { + match value { + $( + $oid => Self::$name, + )* + other => Self::Unknown(other), + } + } + } + + impl From for u32 { + fn from(value: DataTypeOid) -> Self { + match value { + $( + DataTypeOid::$name => $oid, + )* + DataTypeOid::Unknown(other) => other, + } + } + } + }; +} + +// For oid see: +// https://github.com/sfackler/rust-postgres/blob/master/postgres-types/src/type_gen.rs +data_types! { + Unspecified = 0, 0 + + Bool = 16, 1 + + Int2 = 21, 2 + Int4 = 23, 4 + Int8 = 20, 8 + + Float4 = 700, 4 + Float8 = 701, 8 + + Date = 1082, 4 + Timestamp = 1114, 8 + + Text = 25, -1 +} + +/// Describes how to format a given value or set of values. +#[derive(Debug, Copy, Clone)] +pub enum FormatCode { + /// Use the stable text representation. + Text = 0, + /// Use the less-stable binary representation. 
+ Binary = 1, +} + +impl TryFrom for FormatCode { + type Error = ProtocolError; + + fn try_from(value: i16) -> Result { + match value { + 0 => Ok(FormatCode::Text), + 1 => Ok(FormatCode::Binary), + other => Err(ProtocolError::InvalidFormatCode(other)), + } + } +} + +#[derive(Debug)] +pub struct Startup { + pub requested_protocol_version: (i16, i16), + pub parameters: HashMap, +} + +#[derive(Debug)] +pub enum Describe { + Portal(String), + PreparedStatement(String), +} + +#[derive(Debug)] +pub struct Parse { + pub prepared_statement_name: String, + pub query: String, + pub parameter_types: Vec, +} + +#[derive(Debug)] +pub enum BindFormat { + All(FormatCode), + PerColumn(Vec), +} + +#[derive(Debug)] +pub struct Bind { + pub portal: String, + pub prepared_statement_name: String, + pub parameter_values: Vec, + pub result_format: BindFormat, +} + +#[derive(Debug)] +pub enum BindValue { + Text(String), + Binary(Bytes), +} + +#[derive(Debug)] +pub struct Execute { + pub portal: String, + pub max_rows: Option, +} + +#[derive(Debug)] +pub enum FrontendMessage { + SSLRequest, // for SSL negotiation + Startup(Startup), + Parse(Parse), + Describe(Describe), + Bind(Bind), + Sync, + Execute(Execute), + Query(String), + Terminate, +} + +pub trait BackendMessage: std::fmt::Debug { + const TAG: u8; + + fn encode(&self, dst: &mut BytesMut); +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SqlState { + SuccessfulCompletion, + FeatureNotSupported, + InvalidCursorName, + ConnectionException, + InvalidSQLStatementName, + DataException, + ProtocolViolation, + SyntaxError, + InvalidDatetimeFormat, +} + +impl SqlState { + pub fn code(&self) -> &str { + match self { + Self::SuccessfulCompletion => "00000", + Self::FeatureNotSupported => "0A000", + Self::InvalidCursorName => "34000", + Self::ConnectionException => "08000", + Self::InvalidSQLStatementName => "26000", + Self::DataException => "22000", + Self::ProtocolViolation => "08P01", + Self::SyntaxError => "42601", + 
Self::InvalidDatetimeFormat => "22007",
+        }
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum Severity {
+    Error,
+    Fatal,
+}
+
+impl Severity {
+    pub fn code(&self) -> &str {
+        match self {
+            Self::Fatal => "FATAL",
+            Self::Error => "ERROR",
+        }
+    }
+}
+
+#[derive(thiserror::Error, Debug, Clone)]
+pub struct ErrorResponse {
+    pub sql_state: SqlState,
+    pub severity: Severity,
+    pub message: String,
+}
+
+impl ErrorResponse {
+    pub fn new(sql_state: SqlState, severity: Severity, message: impl Into<String>) -> Self {
+        ErrorResponse {
+            sql_state,
+            severity,
+            message: message.into(),
+        }
+    }
+
+    pub fn error(sql_state: SqlState, message: impl Into<String>) -> Self {
+        Self::new(sql_state, Severity::Error, message)
+    }
+
+    pub fn fatal(sql_state: SqlState, message: impl Into<String>) -> Self {
+        Self::new(sql_state, Severity::Fatal, message) // fix: was Severity::Error, making fatal() a duplicate of error()
+    }
+}
+
+impl Display for ErrorResponse {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}: {}", self.severity.code(), self.message) // fix: placeholder "error" discarded all diagnostic detail
+    }
+}
+
+impl BackendMessage for ErrorResponse {
+    const TAG: u8 = b'E';
+
+    fn encode(&self, dst: &mut BytesMut) {
+        dst.put_u8(b'C');
+        dst.put_slice(self.sql_state.code().as_bytes());
+        dst.put_u8(0);
+        dst.put_u8(b'S');
+        dst.put_slice(self.severity.code().as_bytes());
+        dst.put_u8(0);
+        dst.put_u8(b'M');
+        dst.put_slice(self.message.as_bytes());
+        dst.put_u8(0);
+
+        dst.put_u8(0); // zero byte terminates the error-field list
+    }
+}
+
+#[derive(Debug)]
+pub struct ParameterDescription {}
+
+impl BackendMessage for ParameterDescription {
+    const TAG: u8 = b't';
+
+    fn encode(&self, dst: &mut BytesMut) {
+        dst.put_i16(0);
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct FieldDescription {
+    pub name: String,
+    pub data_type: DataTypeOid,
+}
+
+#[derive(Debug, Clone)]
+pub struct RowDescription {
+    pub fields: Vec<FieldDescription>,
+    pub format_code: FormatCode,
+}
+
+impl BackendMessage for RowDescription {
+    const TAG: u8 = b'T';
+
+    fn encode(&self, dst: &mut BytesMut) {
+        dst.put_i16(self.fields.len() as i16);
+        for field in &self.fields {
dst.put_slice(field.name.as_bytes()); + dst.put_u8(0); + dst.put_i32(0); // table oid + dst.put_i16(0); // column attr number + dst.put_u32(field.data_type.into()); + dst.put_i16(field.data_type.size_bytes()); + dst.put_i32(-1); // data type modifier + dst.put_i16(self.format_code as i16); + } + } +} + +#[derive(Debug)] +pub struct AuthenticationOk; + +impl BackendMessage for AuthenticationOk { + const TAG: u8 = b'R'; + + fn encode(&self, dst: &mut BytesMut) { + dst.put_i32(0); + } +} + +#[derive(Debug)] +pub struct ReadyForQuery; + +impl BackendMessage for ReadyForQuery { + const TAG: u8 = b'Z'; + + fn encode(&self, dst: &mut BytesMut) { + dst.put_u8(b'I'); + } +} + +#[derive(Debug)] +pub struct ParseComplete; + +impl BackendMessage for ParseComplete { + const TAG: u8 = b'1'; + + fn encode(&self, _dst: &mut BytesMut) {} +} + +#[derive(Debug)] +pub struct BindComplete; + +impl BackendMessage for BindComplete { + const TAG: u8 = b'2'; + + fn encode(&self, _dst: &mut BytesMut) {} +} + +#[derive(Debug)] +pub struct NoData; + +impl BackendMessage for NoData { + const TAG: u8 = b'n'; + + fn encode(&self, _dst: &mut BytesMut) {} +} + +#[derive(Debug)] +pub struct EmptyQueryResponse; + +impl BackendMessage for EmptyQueryResponse { + const TAG: u8 = b'I'; + + fn encode(&self, _dst: &mut BytesMut) {} +} + +#[derive(Debug)] +pub struct CommandComplete { + pub command_tag: String, +} + +impl BackendMessage for CommandComplete { + const TAG: u8 = b'C'; + + fn encode(&self, dst: &mut BytesMut) { + dst.put_slice(self.command_tag.as_bytes()); + dst.put_u8(0); + } +} + +#[derive(Debug)] +pub struct ParameterStatus { + name: String, + value: String, +} + +impl BackendMessage for ParameterStatus { + const TAG: u8 = b'S'; + + fn encode(&self, dst: &mut BytesMut) { + dst.put_slice(self.name.as_bytes()); + dst.put_u8(0); + dst.put_slice(self.value.as_bytes()); + dst.put_u8(0); + } +} + +impl ParameterStatus { + pub fn new(name: impl Into, value: impl Into) -> Self { + Self { + name: 
name.into(), + value: value.into(), + } + } +} + +#[derive(Default, Debug)] +pub struct ConnectionCodec { + // most state tracking is handled at a higher level + // however, the actual wire format uses a different header for startup vs normal messages + // so we need to be able to differentiate inside the decoder + startup_received: bool, +} + +impl ConnectionCodec { + pub fn new() -> Self { + Self { + startup_received: false, + } + } +} + +#[derive(thiserror::Error, Debug)] +pub enum ProtocolError { + #[error("io error: {0}")] + Io(#[from] std::io::Error), + #[error("utf8 error: {0}")] + Utf8(#[from] std::string::FromUtf8Error), + #[error("parsing error")] + ParserError, + #[error("invalid message type: {0}")] + InvalidMessageType(u8), + #[error("invalid format code: {0}")] + InvalidFormatCode(i16), +} + +// length prefix, two version components +const STARTUP_HEADER_SIZE: usize = size_of::() + (size_of::() * 2); +// message tag, length prefix +const MESSAGE_HEADER_SIZE: usize = size_of::() + size_of::(); + +impl Decoder for ConnectionCodec { + type Item = FrontendMessage; + type Error = ProtocolError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + if !self.startup_received { + if src.len() < STARTUP_HEADER_SIZE { + return Ok(None); + } + + let mut header_buf = src.clone(); + let message_len = header_buf.get_i32() as usize; + let protocol_version_major = header_buf.get_i16(); + let protocol_version_minor = header_buf.get_i16(); + + if protocol_version_major == 1234i16 && protocol_version_minor == 5679i16 { + src.advance(STARTUP_HEADER_SIZE); + return Ok(Some(FrontendMessage::SSLRequest)); + } + + if src.len() < message_len { + src.reserve(message_len - src.len()); + return Ok(None); + } + + src.advance(STARTUP_HEADER_SIZE); + + let mut parameters = HashMap::new(); + + let mut param_str_start_pos = 0; + let mut current_key = None; + for (i, &blah) in src.iter().enumerate() { + if blah == 0 { + let string_value = 
String::from_utf8(src[param_str_start_pos..i].to_owned())?;
+                    param_str_start_pos = i + 1;
+
+                    current_key = match current_key {
+                        Some(key) => {
+                            parameters.insert(key, string_value);
+                            None
+                        }
+                        None => Some(string_value),
+                    }
+                }
+            }
+
+            src.advance(message_len - STARTUP_HEADER_SIZE);
+
+            self.startup_received = true;
+            return Ok(Some(FrontendMessage::Startup(Startup {
+                requested_protocol_version: (protocol_version_major, protocol_version_minor),
+                parameters,
+            })));
+        }
+
+        if src.len() < MESSAGE_HEADER_SIZE {
+            src.reserve(MESSAGE_HEADER_SIZE);
+            return Ok(None);
+        }
+
+        let mut header_buf = src.clone();
+        let message_tag = header_buf.get_u8();
+        let message_len = header_buf.get_i32() as usize;
+
+        if src.len() < message_len + 1 { // fix: length counts itself but not the tag byte, so the message spans message_len + 1 bytes
+            src.reserve(message_len + 1 - src.len());
+            return Ok(None);
+        }
+
+        src.advance(MESSAGE_HEADER_SIZE);
+
+        let read_cstr = |src: &mut BytesMut| -> Result<String, ProtocolError> {
+            let next_null = src
+                .iter()
+                .position(|&b| b == 0)
+                .ok_or(ProtocolError::ParserError)?;
+            let bytes = src[..next_null].to_owned();
+            src.advance(bytes.len() + 1);
+            Ok(String::from_utf8(bytes)?)
+ }; + + let message = match message_tag { + b'P' => { + let prepared_statement_name = read_cstr(src)?; + let query = read_cstr(src)?; + let num_params = src.get_i16(); + let _params: Vec<_> = (0..num_params).map(|_| src.get_u32()).collect(); + + FrontendMessage::Parse(Parse { + prepared_statement_name, + query, + parameter_types: Vec::new(), + }) + } + b'D' => { + let target_type = src.get_u8(); + let name = read_cstr(src)?; + + FrontendMessage::Describe(match target_type { + b'P' => Describe::Portal(name), + b'S' => Describe::PreparedStatement(name), + _ => return Err(ProtocolError::ParserError), + }) + } + b'S' => FrontendMessage::Sync, + b'B' => { + let portal = read_cstr(src)?; + let prepared_statement_name = read_cstr(src)?; + + let num_param_format_codes = src.get_i16(); + + let mut format_codes: Vec = vec![]; + for _ in 0..num_param_format_codes { + format_codes.push(src.get_i16().try_into()?); + } + + let num_params = src.get_i16(); + let mut params = vec![]; + + let mut last_error = None; + + for i in 0..num_params { + let param_len = src.get_i32() as usize; + let format_code = if num_param_format_codes == 0 { + FormatCode::Text + } else if num_param_format_codes == 1 { + format_codes[0] + } else if format_codes.len() >= (i + 1) as usize { + format_codes[i as usize] + } else { + last_error = Some(ProtocolError::ParserError); + FormatCode::Text + }; + + let bytes = src.copy_to_bytes(param_len); + params.push(match format_code { + FormatCode::Binary => BindValue::Binary(bytes), + FormatCode::Text => match String::from_utf8(bytes.to_vec()) { + Ok(s) => BindValue::Text(s), + Err(e) => { + last_error = Some(ProtocolError::Utf8(e)); + continue; + } + }, + }); + } + + let result_format = match src.get_i16() { + 0 => BindFormat::All(FormatCode::Text), + 1 => BindFormat::All(src.get_i16().try_into()?), + n => { + let mut result_format_codes = Vec::new(); + for _ in 0..n { + result_format_codes.push(src.get_i16().try_into()?); + } + 
BindFormat::PerColumn(result_format_codes) + } + }; + + if let Some(e) = last_error { + return Err(e); + } + + FrontendMessage::Bind(Bind { + portal, + prepared_statement_name, + parameter_values: params, + result_format, + }) + } + b'E' => { + let portal = read_cstr(src)?; + let max_rows = match src.get_i32() { + 0 => None, + other => Some(other), + }; + + FrontendMessage::Execute(Execute { portal, max_rows }) + } + b'Q' => { + let query = read_cstr(src)?; + FrontendMessage::Query(query) + } + b'X' => FrontendMessage::Terminate, + other => return Err(ProtocolError::InvalidMessageType(other)), + }; + + Ok(Some(message)) + } +} + +impl Encoder for ConnectionCodec { + type Error = ProtocolError; + + fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> { + let mut body = BytesMut::new(); + item.encode(&mut body); + + dst.put_u8(T::TAG); + dst.put_i32((body.len() + 4) as i32); + dst.put_slice(&body); + Ok(()) + } +} + +pub struct SSLResponse(pub bool); + +impl Encoder for ConnectionCodec { + type Error = ProtocolError; + + fn encode(&mut self, item: SSLResponse, dst: &mut BytesMut) -> Result<(), Self::Error> { + dst.put_u8(if item.0 { b'S' } else { b'N' }); + Ok(()) + } +} diff --git a/crates/corro-pg/src/proto_ext.rs b/crates/corro-pg/src/proto_ext.rs new file mode 100644 index 00000000..d7582a1d --- /dev/null +++ b/crates/corro-pg/src/proto_ext.rs @@ -0,0 +1,163 @@ +//! Contains extensions that make working with the Postgres protocol simpler or more efficient. + +use crate::proto::{ConnectionCodec, FormatCode, ProtocolError, RowDescription}; +use bytes::{BufMut, BytesMut}; +use time::{Date, PrimitiveDateTime}; +use tokio_util::codec::Encoder; + +/// Supports batched rows for e.g. returning portal result sets. +/// +/// NB: this struct only performs limited validation of column consistency across rows. 
+pub struct DataRowBatch { + format_code: FormatCode, + num_cols: usize, + num_rows: usize, + data: BytesMut, + row: BytesMut, +} + +impl DataRowBatch { + /// Creates a new row batch using the given format code, requiring a certain number of columns per row. + pub fn new(format_code: FormatCode, num_cols: usize) -> Self { + Self { + format_code, + num_cols, + num_rows: 0, + data: BytesMut::new(), + row: BytesMut::new(), + } + } + + /// Creates a [DataRowBatch] from the given [RowDescription]. + pub fn from_row_desc(desc: &RowDescription) -> Self { + Self::new(desc.format_code, desc.fields.len()) + } + + /// Starts writing a new row. + /// + /// Returns a [DataRowWriter] that is responsible for the actual value encoding. + pub fn create_row(&mut self) -> DataRowWriter { + self.num_rows += 1; + DataRowWriter::new(self) + } + + /// Returns the number of rows currently written to this batch. + pub fn num_rows(&self) -> usize { + self.num_rows + } +} + +macro_rules! primitive_write { + ($name: ident, $type: ident) => { + #[allow(missing_docs)] + pub fn $name(&mut self, val: $type) { + match self.parent.format_code { + FormatCode::Text => self.write_value(&val.to_string().into_bytes()), + FormatCode::Binary => self.write_value(&val.to_be_bytes()), + }; + } + }; +} + +/// Temporarily leased from a [DataRowBatch] to encode a single row. +pub struct DataRowWriter<'a> { + current_col: usize, + parent: &'a mut DataRowBatch, +} + +impl<'a> DataRowWriter<'a> { + fn new(parent: &'a mut DataRowBatch) -> Self { + parent.row.put_i16(parent.num_cols as i16); + Self { + current_col: 0, + parent, + } + } + + fn write_value(&mut self, data: &[u8]) { + self.current_col += 1; + self.parent.row.put_i32(data.len() as i32); + self.parent.row.put_slice(data); + } + + /// Writes a null value for the next column. + pub fn write_null(&mut self) { + self.current_col += 1; + self.parent.row.put_i32(-1); + } + + /// Writes a string value for the next column. 
+ pub fn write_string(&mut self, val: &str) { + self.write_value(val.as_bytes()); + } + + /// Writes a bool value for the next column. + pub fn write_bool(&mut self, val: bool) { + match self.parent.format_code { + FormatCode::Text => self.write_value(if val { "t" } else { "f" }.as_bytes()), + FormatCode::Binary => { + self.current_col += 1; + self.parent.row.put_u8(val as u8); + } + }; + } + + fn pg_date_epoch() -> Date { + Date::from_calendar_date(2000, time::Month::January, 1) + .expect("failed to create pg date epoch") + } + + fn pg_timestamp_epoch() -> PrimitiveDateTime { + Self::pg_date_epoch() + .with_hms(0, 0, 0) + .expect("failed to create pg timestamp epoch") + } + + /// Writes a date value for the next column. + pub fn write_date(&mut self, val: Date) { + match self.parent.format_code { + FormatCode::Binary => { + self.write_int4((val - Self::pg_date_epoch()).whole_days() as i32) + } + FormatCode::Text => self.write_string(&val.to_string()), + } + } + + /// Writes a timestamp value for the next column. 
+ pub fn write_timestamp(&mut self, val: PrimitiveDateTime) { + match self.parent.format_code { + FormatCode::Binary => { + self.write_int8((val - Self::pg_timestamp_epoch()).whole_microseconds() as i64); + } + FormatCode::Text => self.write_string(&val.to_string()), + } + } + + primitive_write!(write_int2, i16); + primitive_write!(write_int4, i32); + primitive_write!(write_int8, i64); + primitive_write!(write_float4, f32); + primitive_write!(write_float8, f64); +} + +impl<'a> Drop for DataRowWriter<'a> { + fn drop(&mut self) { + assert_eq!( + self.parent.num_cols, self.current_col, + "dropped a row writer with an invalid number of columns" + ); + + self.parent.data.put_u8(b'D'); + self.parent.data.put_i32((self.parent.row.len() + 4) as i32); + self.parent.data.extend(self.parent.row.split()); + } +} + +impl Encoder for ConnectionCodec { + type Error = ProtocolError; + + fn encode(&mut self, item: DataRowBatch, dst: &mut BytesMut) -> Result<(), Self::Error> { + dst.extend(item.data); + Ok(()) + } +} diff --git a/crates/corro-pg/src/sql_state.rs b/crates/corro-pg/src/sql_state.rs new file mode 100644 index 00000000..d8300247 --- /dev/null +++ b/crates/corro-pg/src/sql_state.rs @@ -0,0 +1,1668 @@ +/// A SQLSTATE error code +#[derive(PartialEq, Eq, Clone, Debug)] +pub struct SqlState(Inner); + +impl SqlState { + /// Creates a `SqlState` from its error code. + pub fn from_code(s: &str) -> SqlState { + match SQLSTATE_MAP.get(s) { + Some(state) => state.clone(), + None => SqlState(Inner::Other(s.into())), + } + } + + /// Returns the error code corresponding to the `SqlState`. 
+ pub fn code(&self) -> &str { + match &self.0 { + Inner::E00000 => "00000", + Inner::E01000 => "01000", + Inner::E0100C => "0100C", + Inner::E01008 => "01008", + Inner::E01003 => "01003", + Inner::E01007 => "01007", + Inner::E01006 => "01006", + Inner::E01004 => "01004", + Inner::E01P01 => "01P01", + Inner::E02000 => "02000", + Inner::E02001 => "02001", + Inner::E03000 => "03000", + Inner::E08000 => "08000", + Inner::E08003 => "08003", + Inner::E08006 => "08006", + Inner::E08001 => "08001", + Inner::E08004 => "08004", + Inner::E08007 => "08007", + Inner::E08P01 => "08P01", + Inner::E09000 => "09000", + Inner::E0A000 => "0A000", + Inner::E0B000 => "0B000", + Inner::E0F000 => "0F000", + Inner::E0F001 => "0F001", + Inner::E0L000 => "0L000", + Inner::E0LP01 => "0LP01", + Inner::E0P000 => "0P000", + Inner::E0Z000 => "0Z000", + Inner::E0Z002 => "0Z002", + Inner::E20000 => "20000", + Inner::E21000 => "21000", + Inner::E22000 => "22000", + Inner::E2202E => "2202E", + Inner::E22021 => "22021", + Inner::E22008 => "22008", + Inner::E22012 => "22012", + Inner::E22005 => "22005", + Inner::E2200B => "2200B", + Inner::E22022 => "22022", + Inner::E22015 => "22015", + Inner::E2201E => "2201E", + Inner::E22014 => "22014", + Inner::E22016 => "22016", + Inner::E2201F => "2201F", + Inner::E2201G => "2201G", + Inner::E22018 => "22018", + Inner::E22007 => "22007", + Inner::E22019 => "22019", + Inner::E2200D => "2200D", + Inner::E22025 => "22025", + Inner::E22P06 => "22P06", + Inner::E22010 => "22010", + Inner::E22023 => "22023", + Inner::E22013 => "22013", + Inner::E2201B => "2201B", + Inner::E2201W => "2201W", + Inner::E2201X => "2201X", + Inner::E2202H => "2202H", + Inner::E2202G => "2202G", + Inner::E22009 => "22009", + Inner::E2200C => "2200C", + Inner::E2200G => "2200G", + Inner::E22004 => "22004", + Inner::E22002 => "22002", + Inner::E22003 => "22003", + Inner::E2200H => "2200H", + Inner::E22026 => "22026", + Inner::E22001 => "22001", + Inner::E22011 => "22011", + Inner::E22027 => 
"22027", + Inner::E22024 => "22024", + Inner::E2200F => "2200F", + Inner::E22P01 => "22P01", + Inner::E22P02 => "22P02", + Inner::E22P03 => "22P03", + Inner::E22P04 => "22P04", + Inner::E22P05 => "22P05", + Inner::E2200L => "2200L", + Inner::E2200M => "2200M", + Inner::E2200N => "2200N", + Inner::E2200S => "2200S", + Inner::E2200T => "2200T", + Inner::E22030 => "22030", + Inner::E22031 => "22031", + Inner::E22032 => "22032", + Inner::E22033 => "22033", + Inner::E22034 => "22034", + Inner::E22035 => "22035", + Inner::E22036 => "22036", + Inner::E22037 => "22037", + Inner::E22038 => "22038", + Inner::E22039 => "22039", + Inner::E2203A => "2203A", + Inner::E2203B => "2203B", + Inner::E2203C => "2203C", + Inner::E2203D => "2203D", + Inner::E2203E => "2203E", + Inner::E2203F => "2203F", + Inner::E2203G => "2203G", + Inner::E23000 => "23000", + Inner::E23001 => "23001", + Inner::E23502 => "23502", + Inner::E23503 => "23503", + Inner::E23505 => "23505", + Inner::E23514 => "23514", + Inner::E23P01 => "23P01", + Inner::E24000 => "24000", + Inner::E25000 => "25000", + Inner::E25001 => "25001", + Inner::E25002 => "25002", + Inner::E25008 => "25008", + Inner::E25003 => "25003", + Inner::E25004 => "25004", + Inner::E25005 => "25005", + Inner::E25006 => "25006", + Inner::E25007 => "25007", + Inner::E25P01 => "25P01", + Inner::E25P02 => "25P02", + Inner::E25P03 => "25P03", + Inner::E26000 => "26000", + Inner::E27000 => "27000", + Inner::E28000 => "28000", + Inner::E28P01 => "28P01", + Inner::E2B000 => "2B000", + Inner::E2BP01 => "2BP01", + Inner::E2D000 => "2D000", + Inner::E2F000 => "2F000", + Inner::E2F005 => "2F005", + Inner::E2F002 => "2F002", + Inner::E2F003 => "2F003", + Inner::E2F004 => "2F004", + Inner::E34000 => "34000", + Inner::E38000 => "38000", + Inner::E38001 => "38001", + Inner::E38002 => "38002", + Inner::E38003 => "38003", + Inner::E38004 => "38004", + Inner::E39000 => "39000", + Inner::E39001 => "39001", + Inner::E39004 => "39004", + Inner::E39P01 => "39P01", + 
Inner::E39P02 => "39P02", + Inner::E39P03 => "39P03", + Inner::E3B000 => "3B000", + Inner::E3B001 => "3B001", + Inner::E3D000 => "3D000", + Inner::E3F000 => "3F000", + Inner::E40000 => "40000", + Inner::E40002 => "40002", + Inner::E40001 => "40001", + Inner::E40003 => "40003", + Inner::E40P01 => "40P01", + Inner::E42000 => "42000", + Inner::E42601 => "42601", + Inner::E42501 => "42501", + Inner::E42846 => "42846", + Inner::E42803 => "42803", + Inner::E42P20 => "42P20", + Inner::E42P19 => "42P19", + Inner::E42830 => "42830", + Inner::E42602 => "42602", + Inner::E42622 => "42622", + Inner::E42939 => "42939", + Inner::E42804 => "42804", + Inner::E42P18 => "42P18", + Inner::E42P21 => "42P21", + Inner::E42P22 => "42P22", + Inner::E42809 => "42809", + Inner::E428C9 => "428C9", + Inner::E42703 => "42703", + Inner::E42883 => "42883", + Inner::E42P01 => "42P01", + Inner::E42P02 => "42P02", + Inner::E42704 => "42704", + Inner::E42701 => "42701", + Inner::E42P03 => "42P03", + Inner::E42P04 => "42P04", + Inner::E42723 => "42723", + Inner::E42P05 => "42P05", + Inner::E42P06 => "42P06", + Inner::E42P07 => "42P07", + Inner::E42712 => "42712", + Inner::E42710 => "42710", + Inner::E42702 => "42702", + Inner::E42725 => "42725", + Inner::E42P08 => "42P08", + Inner::E42P09 => "42P09", + Inner::E42P10 => "42P10", + Inner::E42611 => "42611", + Inner::E42P11 => "42P11", + Inner::E42P12 => "42P12", + Inner::E42P13 => "42P13", + Inner::E42P14 => "42P14", + Inner::E42P15 => "42P15", + Inner::E42P16 => "42P16", + Inner::E42P17 => "42P17", + Inner::E44000 => "44000", + Inner::E53000 => "53000", + Inner::E53100 => "53100", + Inner::E53200 => "53200", + Inner::E53300 => "53300", + Inner::E53400 => "53400", + Inner::E54000 => "54000", + Inner::E54001 => "54001", + Inner::E54011 => "54011", + Inner::E54023 => "54023", + Inner::E55000 => "55000", + Inner::E55006 => "55006", + Inner::E55P02 => "55P02", + Inner::E55P03 => "55P03", + Inner::E55P04 => "55P04", + Inner::E57000 => "57000", + 
Inner::E57014 => "57014", + Inner::E57P01 => "57P01", + Inner::E57P02 => "57P02", + Inner::E57P03 => "57P03", + Inner::E57P04 => "57P04", + Inner::E57P05 => "57P05", + Inner::E58000 => "58000", + Inner::E58030 => "58030", + Inner::E58P01 => "58P01", + Inner::E58P02 => "58P02", + Inner::E72000 => "72000", + Inner::EF0000 => "F0000", + Inner::EF0001 => "F0001", + Inner::EHV000 => "HV000", + Inner::EHV005 => "HV005", + Inner::EHV002 => "HV002", + Inner::EHV010 => "HV010", + Inner::EHV021 => "HV021", + Inner::EHV024 => "HV024", + Inner::EHV007 => "HV007", + Inner::EHV008 => "HV008", + Inner::EHV004 => "HV004", + Inner::EHV006 => "HV006", + Inner::EHV091 => "HV091", + Inner::EHV00B => "HV00B", + Inner::EHV00C => "HV00C", + Inner::EHV00D => "HV00D", + Inner::EHV090 => "HV090", + Inner::EHV00A => "HV00A", + Inner::EHV009 => "HV009", + Inner::EHV014 => "HV014", + Inner::EHV001 => "HV001", + Inner::EHV00P => "HV00P", + Inner::EHV00J => "HV00J", + Inner::EHV00K => "HV00K", + Inner::EHV00Q => "HV00Q", + Inner::EHV00R => "HV00R", + Inner::EHV00L => "HV00L", + Inner::EHV00M => "HV00M", + Inner::EHV00N => "HV00N", + Inner::EP0000 => "P0000", + Inner::EP0001 => "P0001", + Inner::EP0002 => "P0002", + Inner::EP0003 => "P0003", + Inner::EP0004 => "P0004", + Inner::EXX000 => "XX000", + Inner::EXX001 => "XX001", + Inner::EXX002 => "XX002", + Inner::Other(code) => code, + } + } + + /// 00000 + pub const SUCCESSFUL_COMPLETION: SqlState = SqlState(Inner::E00000); + + /// 01000 + pub const WARNING: SqlState = SqlState(Inner::E01000); + + /// 0100C + pub const WARNING_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Inner::E0100C); + + /// 01008 + pub const WARNING_IMPLICIT_ZERO_BIT_PADDING: SqlState = SqlState(Inner::E01008); + + /// 01003 + pub const WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION: SqlState = SqlState(Inner::E01003); + + /// 01007 + pub const WARNING_PRIVILEGE_NOT_GRANTED: SqlState = SqlState(Inner::E01007); + + /// 01006 + pub const WARNING_PRIVILEGE_NOT_REVOKED: 
SqlState = SqlState(Inner::E01006); + + /// 01004 + pub const WARNING_STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Inner::E01004); + + /// 01P01 + pub const WARNING_DEPRECATED_FEATURE: SqlState = SqlState(Inner::E01P01); + + /// 02000 + pub const NO_DATA: SqlState = SqlState(Inner::E02000); + + /// 02001 + pub const NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Inner::E02001); + + /// 03000 + pub const SQL_STATEMENT_NOT_YET_COMPLETE: SqlState = SqlState(Inner::E03000); + + /// 08000 + pub const CONNECTION_EXCEPTION: SqlState = SqlState(Inner::E08000); + + /// 08003 + pub const CONNECTION_DOES_NOT_EXIST: SqlState = SqlState(Inner::E08003); + + /// 08006 + pub const CONNECTION_FAILURE: SqlState = SqlState(Inner::E08006); + + /// 08001 + pub const SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: SqlState = SqlState(Inner::E08001); + + /// 08004 + pub const SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION: SqlState = SqlState(Inner::E08004); + + /// 08007 + pub const TRANSACTION_RESOLUTION_UNKNOWN: SqlState = SqlState(Inner::E08007); + + /// 08P01 + pub const PROTOCOL_VIOLATION: SqlState = SqlState(Inner::E08P01); + + /// 09000 + pub const TRIGGERED_ACTION_EXCEPTION: SqlState = SqlState(Inner::E09000); + + /// 0A000 + pub const FEATURE_NOT_SUPPORTED: SqlState = SqlState(Inner::E0A000); + + /// 0B000 + pub const INVALID_TRANSACTION_INITIATION: SqlState = SqlState(Inner::E0B000); + + /// 0F000 + pub const LOCATOR_EXCEPTION: SqlState = SqlState(Inner::E0F000); + + /// 0F001 + pub const L_E_INVALID_SPECIFICATION: SqlState = SqlState(Inner::E0F001); + + /// 0L000 + pub const INVALID_GRANTOR: SqlState = SqlState(Inner::E0L000); + + /// 0LP01 + pub const INVALID_GRANT_OPERATION: SqlState = SqlState(Inner::E0LP01); + + /// 0P000 + pub const INVALID_ROLE_SPECIFICATION: SqlState = SqlState(Inner::E0P000); + + /// 0Z000 + pub const DIAGNOSTICS_EXCEPTION: SqlState = SqlState(Inner::E0Z000); + + /// 0Z002 + pub const 
STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER: SqlState = + SqlState(Inner::E0Z002); + + /// 20000 + pub const CASE_NOT_FOUND: SqlState = SqlState(Inner::E20000); + + /// 21000 + pub const CARDINALITY_VIOLATION: SqlState = SqlState(Inner::E21000); + + /// 22000 + pub const DATA_EXCEPTION: SqlState = SqlState(Inner::E22000); + + /// 2202E + pub const ARRAY_ELEMENT_ERROR: SqlState = SqlState(Inner::E2202E); + + /// 2202E + pub const ARRAY_SUBSCRIPT_ERROR: SqlState = SqlState(Inner::E2202E); + + /// 22021 + pub const CHARACTER_NOT_IN_REPERTOIRE: SqlState = SqlState(Inner::E22021); + + /// 22008 + pub const DATETIME_FIELD_OVERFLOW: SqlState = SqlState(Inner::E22008); + + /// 22008 + pub const DATETIME_VALUE_OUT_OF_RANGE: SqlState = SqlState(Inner::E22008); + + /// 22012 + pub const DIVISION_BY_ZERO: SqlState = SqlState(Inner::E22012); + + /// 22005 + pub const ERROR_IN_ASSIGNMENT: SqlState = SqlState(Inner::E22005); + + /// 2200B + pub const ESCAPE_CHARACTER_CONFLICT: SqlState = SqlState(Inner::E2200B); + + /// 22022 + pub const INDICATOR_OVERFLOW: SqlState = SqlState(Inner::E22022); + + /// 22015 + pub const INTERVAL_FIELD_OVERFLOW: SqlState = SqlState(Inner::E22015); + + /// 2201E + pub const INVALID_ARGUMENT_FOR_LOG: SqlState = SqlState(Inner::E2201E); + + /// 22014 + pub const INVALID_ARGUMENT_FOR_NTILE: SqlState = SqlState(Inner::E22014); + + /// 22016 + pub const INVALID_ARGUMENT_FOR_NTH_VALUE: SqlState = SqlState(Inner::E22016); + + /// 2201F + pub const INVALID_ARGUMENT_FOR_POWER_FUNCTION: SqlState = SqlState(Inner::E2201F); + + /// 2201G + pub const INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION: SqlState = SqlState(Inner::E2201G); + + /// 22018 + pub const INVALID_CHARACTER_VALUE_FOR_CAST: SqlState = SqlState(Inner::E22018); + + /// 22007 + pub const INVALID_DATETIME_FORMAT: SqlState = SqlState(Inner::E22007); + + /// 22019 + pub const INVALID_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E22019); + + /// 2200D + pub const INVALID_ESCAPE_OCTET: SqlState = 
SqlState(Inner::E2200D); + + /// 22025 + pub const INVALID_ESCAPE_SEQUENCE: SqlState = SqlState(Inner::E22025); + + /// 22P06 + pub const NONSTANDARD_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E22P06); + + /// 22010 + pub const INVALID_INDICATOR_PARAMETER_VALUE: SqlState = SqlState(Inner::E22010); + + /// 22023 + pub const INVALID_PARAMETER_VALUE: SqlState = SqlState(Inner::E22023); + + /// 22013 + pub const INVALID_PRECEDING_OR_FOLLOWING_SIZE: SqlState = SqlState(Inner::E22013); + + /// 2201B + pub const INVALID_REGULAR_EXPRESSION: SqlState = SqlState(Inner::E2201B); + + /// 2201W + pub const INVALID_ROW_COUNT_IN_LIMIT_CLAUSE: SqlState = SqlState(Inner::E2201W); + + /// 2201X + pub const INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE: SqlState = SqlState(Inner::E2201X); + + /// 2202H + pub const INVALID_TABLESAMPLE_ARGUMENT: SqlState = SqlState(Inner::E2202H); + + /// 2202G + pub const INVALID_TABLESAMPLE_REPEAT: SqlState = SqlState(Inner::E2202G); + + /// 22009 + pub const INVALID_TIME_ZONE_DISPLACEMENT_VALUE: SqlState = SqlState(Inner::E22009); + + /// 2200C + pub const INVALID_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E2200C); + + /// 2200G + pub const MOST_SPECIFIC_TYPE_MISMATCH: SqlState = SqlState(Inner::E2200G); + + /// 22004 + pub const NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Inner::E22004); + + /// 22002 + pub const NULL_VALUE_NO_INDICATOR_PARAMETER: SqlState = SqlState(Inner::E22002); + + /// 22003 + pub const NUMERIC_VALUE_OUT_OF_RANGE: SqlState = SqlState(Inner::E22003); + + /// 2200H + pub const SEQUENCE_GENERATOR_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E2200H); + + /// 22026 + pub const STRING_DATA_LENGTH_MISMATCH: SqlState = SqlState(Inner::E22026); + + /// 22001 + pub const STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Inner::E22001); + + /// 22011 + pub const SUBSTRING_ERROR: SqlState = SqlState(Inner::E22011); + + /// 22027 + pub const TRIM_ERROR: SqlState = SqlState(Inner::E22027); + + /// 22024 + pub const 
UNTERMINATED_C_STRING: SqlState = SqlState(Inner::E22024); + + /// 2200F + pub const ZERO_LENGTH_CHARACTER_STRING: SqlState = SqlState(Inner::E2200F); + + /// 22P01 + pub const FLOATING_POINT_EXCEPTION: SqlState = SqlState(Inner::E22P01); + + /// 22P02 + pub const INVALID_TEXT_REPRESENTATION: SqlState = SqlState(Inner::E22P02); + + /// 22P03 + pub const INVALID_BINARY_REPRESENTATION: SqlState = SqlState(Inner::E22P03); + + /// 22P04 + pub const BAD_COPY_FILE_FORMAT: SqlState = SqlState(Inner::E22P04); + + /// 22P05 + pub const UNTRANSLATABLE_CHARACTER: SqlState = SqlState(Inner::E22P05); + + /// 2200L + pub const NOT_AN_XML_DOCUMENT: SqlState = SqlState(Inner::E2200L); + + /// 2200M + pub const INVALID_XML_DOCUMENT: SqlState = SqlState(Inner::E2200M); + + /// 2200N + pub const INVALID_XML_CONTENT: SqlState = SqlState(Inner::E2200N); + + /// 2200S + pub const INVALID_XML_COMMENT: SqlState = SqlState(Inner::E2200S); + + /// 2200T + pub const INVALID_XML_PROCESSING_INSTRUCTION: SqlState = SqlState(Inner::E2200T); + + /// 22030 + pub const DUPLICATE_JSON_OBJECT_KEY_VALUE: SqlState = SqlState(Inner::E22030); + + /// 22031 + pub const INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION: SqlState = SqlState(Inner::E22031); + + /// 22032 + pub const INVALID_JSON_TEXT: SqlState = SqlState(Inner::E22032); + + /// 22033 + pub const INVALID_SQL_JSON_SUBSCRIPT: SqlState = SqlState(Inner::E22033); + + /// 22034 + pub const MORE_THAN_ONE_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22034); + + /// 22035 + pub const NO_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22035); + + /// 22036 + pub const NON_NUMERIC_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22036); + + /// 22037 + pub const NON_UNIQUE_KEYS_IN_A_JSON_OBJECT: SqlState = SqlState(Inner::E22037); + + /// 22038 + pub const SINGLETON_SQL_JSON_ITEM_REQUIRED: SqlState = SqlState(Inner::E22038); + + /// 22039 + pub const SQL_JSON_ARRAY_NOT_FOUND: SqlState = SqlState(Inner::E22039); + + /// 2203A + pub const SQL_JSON_MEMBER_NOT_FOUND: 
SqlState = SqlState(Inner::E2203A); + + /// 2203B + pub const SQL_JSON_NUMBER_NOT_FOUND: SqlState = SqlState(Inner::E2203B); + + /// 2203C + pub const SQL_JSON_OBJECT_NOT_FOUND: SqlState = SqlState(Inner::E2203C); + + /// 2203D + pub const TOO_MANY_JSON_ARRAY_ELEMENTS: SqlState = SqlState(Inner::E2203D); + + /// 2203E + pub const TOO_MANY_JSON_OBJECT_MEMBERS: SqlState = SqlState(Inner::E2203E); + + /// 2203F + pub const SQL_JSON_SCALAR_REQUIRED: SqlState = SqlState(Inner::E2203F); + + /// 2203G + pub const SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE: SqlState = SqlState(Inner::E2203G); + + /// 23000 + pub const INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Inner::E23000); + + /// 23001 + pub const RESTRICT_VIOLATION: SqlState = SqlState(Inner::E23001); + + /// 23502 + pub const NOT_NULL_VIOLATION: SqlState = SqlState(Inner::E23502); + + /// 23503 + pub const FOREIGN_KEY_VIOLATION: SqlState = SqlState(Inner::E23503); + + /// 23505 + pub const UNIQUE_VIOLATION: SqlState = SqlState(Inner::E23505); + + /// 23514 + pub const CHECK_VIOLATION: SqlState = SqlState(Inner::E23514); + + /// 23P01 + pub const EXCLUSION_VIOLATION: SqlState = SqlState(Inner::E23P01); + + /// 24000 + pub const INVALID_CURSOR_STATE: SqlState = SqlState(Inner::E24000); + + /// 25000 + pub const INVALID_TRANSACTION_STATE: SqlState = SqlState(Inner::E25000); + + /// 25001 + pub const ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Inner::E25001); + + /// 25002 + pub const BRANCH_TRANSACTION_ALREADY_ACTIVE: SqlState = SqlState(Inner::E25002); + + /// 25008 + pub const HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL: SqlState = SqlState(Inner::E25008); + + /// 25003 + pub const INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Inner::E25003); + + /// 25004 + pub const INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION: SqlState = + SqlState(Inner::E25004); + + /// 25005 + pub const NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Inner::E25005); + + /// 25006 + 
pub const READ_ONLY_SQL_TRANSACTION: SqlState = SqlState(Inner::E25006); + + /// 25007 + pub const SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED: SqlState = SqlState(Inner::E25007); + + /// 25P01 + pub const NO_ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Inner::E25P01); + + /// 25P02 + pub const IN_FAILED_SQL_TRANSACTION: SqlState = SqlState(Inner::E25P02); + + /// 25P03 + pub const IDLE_IN_TRANSACTION_SESSION_TIMEOUT: SqlState = SqlState(Inner::E25P03); + + /// 26000 + pub const INVALID_SQL_STATEMENT_NAME: SqlState = SqlState(Inner::E26000); + + /// 26000 + pub const UNDEFINED_PSTATEMENT: SqlState = SqlState(Inner::E26000); + + /// 27000 + pub const TRIGGERED_DATA_CHANGE_VIOLATION: SqlState = SqlState(Inner::E27000); + + /// 28000 + pub const INVALID_AUTHORIZATION_SPECIFICATION: SqlState = SqlState(Inner::E28000); + + /// 28P01 + pub const INVALID_PASSWORD: SqlState = SqlState(Inner::E28P01); + + /// 2B000 + pub const DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: SqlState = SqlState(Inner::E2B000); + + /// 2BP01 + pub const DEPENDENT_OBJECTS_STILL_EXIST: SqlState = SqlState(Inner::E2BP01); + + /// 2D000 + pub const INVALID_TRANSACTION_TERMINATION: SqlState = SqlState(Inner::E2D000); + + /// 2F000 + pub const SQL_ROUTINE_EXCEPTION: SqlState = SqlState(Inner::E2F000); + + /// 2F005 + pub const S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT: SqlState = SqlState(Inner::E2F005); + + /// 2F002 + pub const S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E2F002); + + /// 2F003 + pub const S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Inner::E2F003); + + /// 2F004 + pub const S_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E2F004); + + /// 34000 + pub const INVALID_CURSOR_NAME: SqlState = SqlState(Inner::E34000); + + /// 34000 + pub const UNDEFINED_CURSOR: SqlState = SqlState(Inner::E34000); + + /// 38000 + pub const EXTERNAL_ROUTINE_EXCEPTION: SqlState = SqlState(Inner::E38000); + + /// 38001 + pub const 
E_R_E_CONTAINING_SQL_NOT_PERMITTED: SqlState = SqlState(Inner::E38001); + + /// 38002 + pub const E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E38002); + + /// 38003 + pub const E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Inner::E38003); + + /// 38004 + pub const E_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E38004); + + /// 39000 + pub const EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: SqlState = SqlState(Inner::E39000); + + /// 39001 + pub const E_R_I_E_INVALID_SQLSTATE_RETURNED: SqlState = SqlState(Inner::E39001); + + /// 39004 + pub const E_R_I_E_NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Inner::E39004); + + /// 39P01 + pub const E_R_I_E_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P01); + + /// 39P02 + pub const E_R_I_E_SRF_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P02); + + /// 39P03 + pub const E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P03); + + /// 3B000 + pub const SAVEPOINT_EXCEPTION: SqlState = SqlState(Inner::E3B000); + + /// 3B001 + pub const S_E_INVALID_SPECIFICATION: SqlState = SqlState(Inner::E3B001); + + /// 3D000 + pub const INVALID_CATALOG_NAME: SqlState = SqlState(Inner::E3D000); + + /// 3D000 + pub const UNDEFINED_DATABASE: SqlState = SqlState(Inner::E3D000); + + /// 3F000 + pub const INVALID_SCHEMA_NAME: SqlState = SqlState(Inner::E3F000); + + /// 3F000 + pub const UNDEFINED_SCHEMA: SqlState = SqlState(Inner::E3F000); + + /// 40000 + pub const TRANSACTION_ROLLBACK: SqlState = SqlState(Inner::E40000); + + /// 40002 + pub const T_R_INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Inner::E40002); + + /// 40001 + pub const T_R_SERIALIZATION_FAILURE: SqlState = SqlState(Inner::E40001); + + /// 40003 + pub const T_R_STATEMENT_COMPLETION_UNKNOWN: SqlState = SqlState(Inner::E40003); + + /// 40P01 + pub const T_R_DEADLOCK_DETECTED: SqlState = SqlState(Inner::E40P01); + + /// 42000 + pub const SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: SqlState = 
SqlState(Inner::E42000); + + /// 42601 + pub const SYNTAX_ERROR: SqlState = SqlState(Inner::E42601); + + /// 42501 + pub const INSUFFICIENT_PRIVILEGE: SqlState = SqlState(Inner::E42501); + + /// 42846 + pub const CANNOT_COERCE: SqlState = SqlState(Inner::E42846); + + /// 42803 + pub const GROUPING_ERROR: SqlState = SqlState(Inner::E42803); + + /// 42P20 + pub const WINDOWING_ERROR: SqlState = SqlState(Inner::E42P20); + + /// 42P19 + pub const INVALID_RECURSION: SqlState = SqlState(Inner::E42P19); + + /// 42830 + pub const INVALID_FOREIGN_KEY: SqlState = SqlState(Inner::E42830); + + /// 42602 + pub const INVALID_NAME: SqlState = SqlState(Inner::E42602); + + /// 42622 + pub const NAME_TOO_LONG: SqlState = SqlState(Inner::E42622); + + /// 42939 + pub const RESERVED_NAME: SqlState = SqlState(Inner::E42939); + + /// 42804 + pub const DATATYPE_MISMATCH: SqlState = SqlState(Inner::E42804); + + /// 42P18 + pub const INDETERMINATE_DATATYPE: SqlState = SqlState(Inner::E42P18); + + /// 42P21 + pub const COLLATION_MISMATCH: SqlState = SqlState(Inner::E42P21); + + /// 42P22 + pub const INDETERMINATE_COLLATION: SqlState = SqlState(Inner::E42P22); + + /// 42809 + pub const WRONG_OBJECT_TYPE: SqlState = SqlState(Inner::E42809); + + /// 428C9 + pub const GENERATED_ALWAYS: SqlState = SqlState(Inner::E428C9); + + /// 42703 + pub const UNDEFINED_COLUMN: SqlState = SqlState(Inner::E42703); + + /// 42883 + pub const UNDEFINED_FUNCTION: SqlState = SqlState(Inner::E42883); + + /// 42P01 + pub const UNDEFINED_TABLE: SqlState = SqlState(Inner::E42P01); + + /// 42P02 + pub const UNDEFINED_PARAMETER: SqlState = SqlState(Inner::E42P02); + + /// 42704 + pub const UNDEFINED_OBJECT: SqlState = SqlState(Inner::E42704); + + /// 42701 + pub const DUPLICATE_COLUMN: SqlState = SqlState(Inner::E42701); + + /// 42P03 + pub const DUPLICATE_CURSOR: SqlState = SqlState(Inner::E42P03); + + /// 42P04 + pub const DUPLICATE_DATABASE: SqlState = SqlState(Inner::E42P04); + + /// 42723 + pub const 
DUPLICATE_FUNCTION: SqlState = SqlState(Inner::E42723); + + /// 42P05 + pub const DUPLICATE_PSTATEMENT: SqlState = SqlState(Inner::E42P05); + + /// 42P06 + pub const DUPLICATE_SCHEMA: SqlState = SqlState(Inner::E42P06); + + /// 42P07 + pub const DUPLICATE_TABLE: SqlState = SqlState(Inner::E42P07); + + /// 42712 + pub const DUPLICATE_ALIAS: SqlState = SqlState(Inner::E42712); + + /// 42710 + pub const DUPLICATE_OBJECT: SqlState = SqlState(Inner::E42710); + + /// 42702 + pub const AMBIGUOUS_COLUMN: SqlState = SqlState(Inner::E42702); + + /// 42725 + pub const AMBIGUOUS_FUNCTION: SqlState = SqlState(Inner::E42725); + + /// 42P08 + pub const AMBIGUOUS_PARAMETER: SqlState = SqlState(Inner::E42P08); + + /// 42P09 + pub const AMBIGUOUS_ALIAS: SqlState = SqlState(Inner::E42P09); + + /// 42P10 + pub const INVALID_COLUMN_REFERENCE: SqlState = SqlState(Inner::E42P10); + + /// 42611 + pub const INVALID_COLUMN_DEFINITION: SqlState = SqlState(Inner::E42611); + + /// 42P11 + pub const INVALID_CURSOR_DEFINITION: SqlState = SqlState(Inner::E42P11); + + /// 42P12 + pub const INVALID_DATABASE_DEFINITION: SqlState = SqlState(Inner::E42P12); + + /// 42P13 + pub const INVALID_FUNCTION_DEFINITION: SqlState = SqlState(Inner::E42P13); + + /// 42P14 + pub const INVALID_PSTATEMENT_DEFINITION: SqlState = SqlState(Inner::E42P14); + + /// 42P15 + pub const INVALID_SCHEMA_DEFINITION: SqlState = SqlState(Inner::E42P15); + + /// 42P16 + pub const INVALID_TABLE_DEFINITION: SqlState = SqlState(Inner::E42P16); + + /// 42P17 + pub const INVALID_OBJECT_DEFINITION: SqlState = SqlState(Inner::E42P17); + + /// 44000 + pub const WITH_CHECK_OPTION_VIOLATION: SqlState = SqlState(Inner::E44000); + + /// 53000 + pub const INSUFFICIENT_RESOURCES: SqlState = SqlState(Inner::E53000); + + /// 53100 + pub const DISK_FULL: SqlState = SqlState(Inner::E53100); + + /// 53200 + pub const OUT_OF_MEMORY: SqlState = SqlState(Inner::E53200); + + /// 53300 + pub const TOO_MANY_CONNECTIONS: SqlState = SqlState(Inner::E53300); 
+ + /// 53400 + pub const CONFIGURATION_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E53400); + + /// 54000 + pub const PROGRAM_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E54000); + + /// 54001 + pub const STATEMENT_TOO_COMPLEX: SqlState = SqlState(Inner::E54001); + + /// 54011 + pub const TOO_MANY_COLUMNS: SqlState = SqlState(Inner::E54011); + + /// 54023 + pub const TOO_MANY_ARGUMENTS: SqlState = SqlState(Inner::E54023); + + /// 55000 + pub const OBJECT_NOT_IN_PREREQUISITE_STATE: SqlState = SqlState(Inner::E55000); + + /// 55006 + pub const OBJECT_IN_USE: SqlState = SqlState(Inner::E55006); + + /// 55P02 + pub const CANT_CHANGE_RUNTIME_PARAM: SqlState = SqlState(Inner::E55P02); + + /// 55P03 + pub const LOCK_NOT_AVAILABLE: SqlState = SqlState(Inner::E55P03); + + /// 55P04 + pub const UNSAFE_NEW_ENUM_VALUE_USAGE: SqlState = SqlState(Inner::E55P04); + + /// 57000 + pub const OPERATOR_INTERVENTION: SqlState = SqlState(Inner::E57000); + + /// 57014 + pub const QUERY_CANCELED: SqlState = SqlState(Inner::E57014); + + /// 57P01 + pub const ADMIN_SHUTDOWN: SqlState = SqlState(Inner::E57P01); + + /// 57P02 + pub const CRASH_SHUTDOWN: SqlState = SqlState(Inner::E57P02); + + /// 57P03 + pub const CANNOT_CONNECT_NOW: SqlState = SqlState(Inner::E57P03); + + /// 57P04 + pub const DATABASE_DROPPED: SqlState = SqlState(Inner::E57P04); + + /// 57P05 + pub const IDLE_SESSION_TIMEOUT: SqlState = SqlState(Inner::E57P05); + + /// 58000 + pub const SYSTEM_ERROR: SqlState = SqlState(Inner::E58000); + + /// 58030 + pub const IO_ERROR: SqlState = SqlState(Inner::E58030); + + /// 58P01 + pub const UNDEFINED_FILE: SqlState = SqlState(Inner::E58P01); + + /// 58P02 + pub const DUPLICATE_FILE: SqlState = SqlState(Inner::E58P02); + + /// 72000 + pub const SNAPSHOT_TOO_OLD: SqlState = SqlState(Inner::E72000); + + /// F0000 + pub const CONFIG_FILE_ERROR: SqlState = SqlState(Inner::EF0000); + + /// F0001 + pub const LOCK_FILE_EXISTS: SqlState = SqlState(Inner::EF0001); + + /// HV000 + pub const 
FDW_ERROR: SqlState = SqlState(Inner::EHV000); + + /// HV005 + pub const FDW_COLUMN_NAME_NOT_FOUND: SqlState = SqlState(Inner::EHV005); + + /// HV002 + pub const FDW_DYNAMIC_PARAMETER_VALUE_NEEDED: SqlState = SqlState(Inner::EHV002); + + /// HV010 + pub const FDW_FUNCTION_SEQUENCE_ERROR: SqlState = SqlState(Inner::EHV010); + + /// HV021 + pub const FDW_INCONSISTENT_DESCRIPTOR_INFORMATION: SqlState = SqlState(Inner::EHV021); + + /// HV024 + pub const FDW_INVALID_ATTRIBUTE_VALUE: SqlState = SqlState(Inner::EHV024); + + /// HV007 + pub const FDW_INVALID_COLUMN_NAME: SqlState = SqlState(Inner::EHV007); + + /// HV008 + pub const FDW_INVALID_COLUMN_NUMBER: SqlState = SqlState(Inner::EHV008); + + /// HV004 + pub const FDW_INVALID_DATA_TYPE: SqlState = SqlState(Inner::EHV004); + + /// HV006 + pub const FDW_INVALID_DATA_TYPE_DESCRIPTORS: SqlState = SqlState(Inner::EHV006); + + /// HV091 + pub const FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER: SqlState = SqlState(Inner::EHV091); + + /// HV00B + pub const FDW_INVALID_HANDLE: SqlState = SqlState(Inner::EHV00B); + + /// HV00C + pub const FDW_INVALID_OPTION_INDEX: SqlState = SqlState(Inner::EHV00C); + + /// HV00D + pub const FDW_INVALID_OPTION_NAME: SqlState = SqlState(Inner::EHV00D); + + /// HV090 + pub const FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH: SqlState = SqlState(Inner::EHV090); + + /// HV00A + pub const FDW_INVALID_STRING_FORMAT: SqlState = SqlState(Inner::EHV00A); + + /// HV009 + pub const FDW_INVALID_USE_OF_NULL_POINTER: SqlState = SqlState(Inner::EHV009); + + /// HV014 + pub const FDW_TOO_MANY_HANDLES: SqlState = SqlState(Inner::EHV014); + + /// HV001 + pub const FDW_OUT_OF_MEMORY: SqlState = SqlState(Inner::EHV001); + + /// HV00P + pub const FDW_NO_SCHEMAS: SqlState = SqlState(Inner::EHV00P); + + /// HV00J + pub const FDW_OPTION_NAME_NOT_FOUND: SqlState = SqlState(Inner::EHV00J); + + /// HV00K + pub const FDW_REPLY_HANDLE: SqlState = SqlState(Inner::EHV00K); + + /// HV00Q + pub const FDW_SCHEMA_NOT_FOUND: SqlState = 
SqlState(Inner::EHV00Q); + + /// HV00R + pub const FDW_TABLE_NOT_FOUND: SqlState = SqlState(Inner::EHV00R); + + /// HV00L + pub const FDW_UNABLE_TO_CREATE_EXECUTION: SqlState = SqlState(Inner::EHV00L); + + /// HV00M + pub const FDW_UNABLE_TO_CREATE_REPLY: SqlState = SqlState(Inner::EHV00M); + + /// HV00N + pub const FDW_UNABLE_TO_ESTABLISH_CONNECTION: SqlState = SqlState(Inner::EHV00N); + + /// P0000 + pub const PLPGSQL_ERROR: SqlState = SqlState(Inner::EP0000); + + /// P0001 + pub const RAISE_EXCEPTION: SqlState = SqlState(Inner::EP0001); + + /// P0002 + pub const NO_DATA_FOUND: SqlState = SqlState(Inner::EP0002); + + /// P0003 + pub const TOO_MANY_ROWS: SqlState = SqlState(Inner::EP0003); + + /// P0004 + pub const ASSERT_FAILURE: SqlState = SqlState(Inner::EP0004); + + /// XX000 + pub const INTERNAL_ERROR: SqlState = SqlState(Inner::EXX000); + + /// XX001 + pub const DATA_CORRUPTED: SqlState = SqlState(Inner::EXX001); + + /// XX002 + pub const INDEX_CORRUPTED: SqlState = SqlState(Inner::EXX002); +} + +#[derive(PartialEq, Eq, Clone, Debug)] +#[allow(clippy::upper_case_acronyms)] +enum Inner { + E00000, + E01000, + E0100C, + E01008, + E01003, + E01007, + E01006, + E01004, + E01P01, + E02000, + E02001, + E03000, + E08000, + E08003, + E08006, + E08001, + E08004, + E08007, + E08P01, + E09000, + E0A000, + E0B000, + E0F000, + E0F001, + E0L000, + E0LP01, + E0P000, + E0Z000, + E0Z002, + E20000, + E21000, + E22000, + E2202E, + E22021, + E22008, + E22012, + E22005, + E2200B, + E22022, + E22015, + E2201E, + E22014, + E22016, + E2201F, + E2201G, + E22018, + E22007, + E22019, + E2200D, + E22025, + E22P06, + E22010, + E22023, + E22013, + E2201B, + E2201W, + E2201X, + E2202H, + E2202G, + E22009, + E2200C, + E2200G, + E22004, + E22002, + E22003, + E2200H, + E22026, + E22001, + E22011, + E22027, + E22024, + E2200F, + E22P01, + E22P02, + E22P03, + E22P04, + E22P05, + E2200L, + E2200M, + E2200N, + E2200S, + E2200T, + E22030, + E22031, + E22032, + E22033, + E22034, + E22035, + 
E22036, + E22037, + E22038, + E22039, + E2203A, + E2203B, + E2203C, + E2203D, + E2203E, + E2203F, + E2203G, + E23000, + E23001, + E23502, + E23503, + E23505, + E23514, + E23P01, + E24000, + E25000, + E25001, + E25002, + E25008, + E25003, + E25004, + E25005, + E25006, + E25007, + E25P01, + E25P02, + E25P03, + E26000, + E27000, + E28000, + E28P01, + E2B000, + E2BP01, + E2D000, + E2F000, + E2F005, + E2F002, + E2F003, + E2F004, + E34000, + E38000, + E38001, + E38002, + E38003, + E38004, + E39000, + E39001, + E39004, + E39P01, + E39P02, + E39P03, + E3B000, + E3B001, + E3D000, + E3F000, + E40000, + E40002, + E40001, + E40003, + E40P01, + E42000, + E42601, + E42501, + E42846, + E42803, + E42P20, + E42P19, + E42830, + E42602, + E42622, + E42939, + E42804, + E42P18, + E42P21, + E42P22, + E42809, + E428C9, + E42703, + E42883, + E42P01, + E42P02, + E42704, + E42701, + E42P03, + E42P04, + E42723, + E42P05, + E42P06, + E42P07, + E42712, + E42710, + E42702, + E42725, + E42P08, + E42P09, + E42P10, + E42611, + E42P11, + E42P12, + E42P13, + E42P14, + E42P15, + E42P16, + E42P17, + E44000, + E53000, + E53100, + E53200, + E53300, + E53400, + E54000, + E54001, + E54011, + E54023, + E55000, + E55006, + E55P02, + E55P03, + E55P04, + E57000, + E57014, + E57P01, + E57P02, + E57P03, + E57P04, + E57P05, + E58000, + E58030, + E58P01, + E58P02, + E72000, + EF0000, + EF0001, + EHV000, + EHV005, + EHV002, + EHV010, + EHV021, + EHV024, + EHV007, + EHV008, + EHV004, + EHV006, + EHV091, + EHV00B, + EHV00C, + EHV00D, + EHV090, + EHV00A, + EHV009, + EHV014, + EHV001, + EHV00P, + EHV00J, + EHV00K, + EHV00Q, + EHV00R, + EHV00L, + EHV00M, + EHV00N, + EP0000, + EP0001, + EP0002, + EP0003, + EP0004, + EXX000, + EXX001, + EXX002, + Other(Box), +} + +#[rustfmt::skip] +static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = +::phf::Map { + key: 12913932095322966823, + disps: &[ + (0, 24), + (0, 12), + (0, 74), + (0, 109), + (0, 11), + (0, 9), + (0, 0), + (4, 38), + (3, 155), + (0, 6), + (1, 242), + (0, 66), 
+ (0, 53), + (5, 180), + (3, 221), + (7, 230), + (0, 125), + (1, 46), + (0, 11), + (1, 2), + (0, 5), + (0, 13), + (0, 171), + (0, 15), + (0, 4), + (0, 22), + (1, 85), + (0, 75), + (2, 0), + (1, 25), + (7, 47), + (0, 45), + (0, 35), + (0, 7), + (7, 124), + (0, 0), + (14, 104), + (1, 183), + (61, 50), + (3, 76), + (0, 12), + (0, 7), + (4, 189), + (0, 1), + (64, 102), + (0, 0), + (16, 192), + (24, 19), + (0, 5), + (0, 87), + (0, 89), + (0, 14), + ], + entries: &[ + ("2F000", SqlState::SQL_ROUTINE_EXCEPTION), + ("01008", SqlState::WARNING_IMPLICIT_ZERO_BIT_PADDING), + ("42501", SqlState::INSUFFICIENT_PRIVILEGE), + ("22000", SqlState::DATA_EXCEPTION), + ("0100C", SqlState::WARNING_DYNAMIC_RESULT_SETS_RETURNED), + ("2200N", SqlState::INVALID_XML_CONTENT), + ("40001", SqlState::T_R_SERIALIZATION_FAILURE), + ("28P01", SqlState::INVALID_PASSWORD), + ("38000", SqlState::EXTERNAL_ROUTINE_EXCEPTION), + ("25006", SqlState::READ_ONLY_SQL_TRANSACTION), + ("2203D", SqlState::TOO_MANY_JSON_ARRAY_ELEMENTS), + ("42P09", SqlState::AMBIGUOUS_ALIAS), + ("F0000", SqlState::CONFIG_FILE_ERROR), + ("42P18", SqlState::INDETERMINATE_DATATYPE), + ("40002", SqlState::T_R_INTEGRITY_CONSTRAINT_VIOLATION), + ("22009", SqlState::INVALID_TIME_ZONE_DISPLACEMENT_VALUE), + ("42P08", SqlState::AMBIGUOUS_PARAMETER), + ("08000", SqlState::CONNECTION_EXCEPTION), + ("25P01", SqlState::NO_ACTIVE_SQL_TRANSACTION), + ("22024", SqlState::UNTERMINATED_C_STRING), + ("55000", SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE), + ("25001", SqlState::ACTIVE_SQL_TRANSACTION), + ("03000", SqlState::SQL_STATEMENT_NOT_YET_COMPLETE), + ("42710", SqlState::DUPLICATE_OBJECT), + ("2D000", SqlState::INVALID_TRANSACTION_TERMINATION), + ("2200G", SqlState::MOST_SPECIFIC_TYPE_MISMATCH), + ("22022", SqlState::INDICATOR_OVERFLOW), + ("55006", SqlState::OBJECT_IN_USE), + ("53200", SqlState::OUT_OF_MEMORY), + ("22012", SqlState::DIVISION_BY_ZERO), + ("P0002", SqlState::NO_DATA_FOUND), + ("XX001", SqlState::DATA_CORRUPTED), + ("22P05", 
SqlState::UNTRANSLATABLE_CHARACTER), + ("40003", SqlState::T_R_STATEMENT_COMPLETION_UNKNOWN), + ("22021", SqlState::CHARACTER_NOT_IN_REPERTOIRE), + ("25000", SqlState::INVALID_TRANSACTION_STATE), + ("42P15", SqlState::INVALID_SCHEMA_DEFINITION), + ("0B000", SqlState::INVALID_TRANSACTION_INITIATION), + ("22004", SqlState::NULL_VALUE_NOT_ALLOWED), + ("42804", SqlState::DATATYPE_MISMATCH), + ("42803", SqlState::GROUPING_ERROR), + ("02001", SqlState::NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED), + ("25002", SqlState::BRANCH_TRANSACTION_ALREADY_ACTIVE), + ("28000", SqlState::INVALID_AUTHORIZATION_SPECIFICATION), + ("HV009", SqlState::FDW_INVALID_USE_OF_NULL_POINTER), + ("22P01", SqlState::FLOATING_POINT_EXCEPTION), + ("2B000", SqlState::DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST), + ("42723", SqlState::DUPLICATE_FUNCTION), + ("21000", SqlState::CARDINALITY_VIOLATION), + ("0Z002", SqlState::STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER), + ("23505", SqlState::UNIQUE_VIOLATION), + ("HV00J", SqlState::FDW_OPTION_NAME_NOT_FOUND), + ("23P01", SqlState::EXCLUSION_VIOLATION), + ("39P03", SqlState::E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED), + ("42P10", SqlState::INVALID_COLUMN_REFERENCE), + ("2202H", SqlState::INVALID_TABLESAMPLE_ARGUMENT), + ("55P04", SqlState::UNSAFE_NEW_ENUM_VALUE_USAGE), + ("P0000", SqlState::PLPGSQL_ERROR), + ("2F005", SqlState::S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT), + ("HV00M", SqlState::FDW_UNABLE_TO_CREATE_REPLY), + ("0A000", SqlState::FEATURE_NOT_SUPPORTED), + ("24000", SqlState::INVALID_CURSOR_STATE), + ("25008", SqlState::HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL), + ("01003", SqlState::WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION), + ("42712", SqlState::DUPLICATE_ALIAS), + ("HV014", SqlState::FDW_TOO_MANY_HANDLES), + ("58030", SqlState::IO_ERROR), + ("2201W", SqlState::INVALID_ROW_COUNT_IN_LIMIT_CLAUSE), + ("22033", SqlState::INVALID_SQL_JSON_SUBSCRIPT), + ("2BP01", SqlState::DEPENDENT_OBJECTS_STILL_EXIST), + ("HV005", 
SqlState::FDW_COLUMN_NAME_NOT_FOUND), + ("25004", SqlState::INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION), + ("54000", SqlState::PROGRAM_LIMIT_EXCEEDED), + ("20000", SqlState::CASE_NOT_FOUND), + ("2203G", SqlState::SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE), + ("22038", SqlState::SINGLETON_SQL_JSON_ITEM_REQUIRED), + ("22007", SqlState::INVALID_DATETIME_FORMAT), + ("08004", SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION), + ("2200H", SqlState::SEQUENCE_GENERATOR_LIMIT_EXCEEDED), + ("HV00D", SqlState::FDW_INVALID_OPTION_NAME), + ("P0004", SqlState::ASSERT_FAILURE), + ("22018", SqlState::INVALID_CHARACTER_VALUE_FOR_CAST), + ("0L000", SqlState::INVALID_GRANTOR), + ("22P04", SqlState::BAD_COPY_FILE_FORMAT), + ("22031", SqlState::INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION), + ("01P01", SqlState::WARNING_DEPRECATED_FEATURE), + ("0LP01", SqlState::INVALID_GRANT_OPERATION), + ("58P02", SqlState::DUPLICATE_FILE), + ("26000", SqlState::INVALID_SQL_STATEMENT_NAME), + ("54001", SqlState::STATEMENT_TOO_COMPLEX), + ("22010", SqlState::INVALID_INDICATOR_PARAMETER_VALUE), + ("HV00C", SqlState::FDW_INVALID_OPTION_INDEX), + ("22008", SqlState::DATETIME_FIELD_OVERFLOW), + ("42P06", SqlState::DUPLICATE_SCHEMA), + ("25007", SqlState::SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED), + ("42P20", SqlState::WINDOWING_ERROR), + ("HV091", SqlState::FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER), + ("HV021", SqlState::FDW_INCONSISTENT_DESCRIPTOR_INFORMATION), + ("42702", SqlState::AMBIGUOUS_COLUMN), + ("02000", SqlState::NO_DATA), + ("54011", SqlState::TOO_MANY_COLUMNS), + ("HV004", SqlState::FDW_INVALID_DATA_TYPE), + ("01006", SqlState::WARNING_PRIVILEGE_NOT_REVOKED), + ("42701", SqlState::DUPLICATE_COLUMN), + ("08P01", SqlState::PROTOCOL_VIOLATION), + ("42622", SqlState::NAME_TOO_LONG), + ("P0003", SqlState::TOO_MANY_ROWS), + ("22003", SqlState::NUMERIC_VALUE_OUT_OF_RANGE), + ("42P03", SqlState::DUPLICATE_CURSOR), + ("23001", SqlState::RESTRICT_VIOLATION), + ("57000", 
SqlState::OPERATOR_INTERVENTION), + ("22027", SqlState::TRIM_ERROR), + ("42P12", SqlState::INVALID_DATABASE_DEFINITION), + ("3B000", SqlState::SAVEPOINT_EXCEPTION), + ("2201B", SqlState::INVALID_REGULAR_EXPRESSION), + ("22030", SqlState::DUPLICATE_JSON_OBJECT_KEY_VALUE), + ("2F004", SqlState::S_R_E_READING_SQL_DATA_NOT_PERMITTED), + ("428C9", SqlState::GENERATED_ALWAYS), + ("2200S", SqlState::INVALID_XML_COMMENT), + ("22039", SqlState::SQL_JSON_ARRAY_NOT_FOUND), + ("42809", SqlState::WRONG_OBJECT_TYPE), + ("2201X", SqlState::INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE), + ("39001", SqlState::E_R_I_E_INVALID_SQLSTATE_RETURNED), + ("25P02", SqlState::IN_FAILED_SQL_TRANSACTION), + ("0P000", SqlState::INVALID_ROLE_SPECIFICATION), + ("HV00N", SqlState::FDW_UNABLE_TO_ESTABLISH_CONNECTION), + ("53100", SqlState::DISK_FULL), + ("42601", SqlState::SYNTAX_ERROR), + ("23000", SqlState::INTEGRITY_CONSTRAINT_VIOLATION), + ("HV006", SqlState::FDW_INVALID_DATA_TYPE_DESCRIPTORS), + ("HV00B", SqlState::FDW_INVALID_HANDLE), + ("HV00Q", SqlState::FDW_SCHEMA_NOT_FOUND), + ("01000", SqlState::WARNING), + ("42883", SqlState::UNDEFINED_FUNCTION), + ("57P01", SqlState::ADMIN_SHUTDOWN), + ("22037", SqlState::NON_UNIQUE_KEYS_IN_A_JSON_OBJECT), + ("00000", SqlState::SUCCESSFUL_COMPLETION), + ("55P03", SqlState::LOCK_NOT_AVAILABLE), + ("42P01", SqlState::UNDEFINED_TABLE), + ("42830", SqlState::INVALID_FOREIGN_KEY), + ("22005", SqlState::ERROR_IN_ASSIGNMENT), + ("22025", SqlState::INVALID_ESCAPE_SEQUENCE), + ("XX002", SqlState::INDEX_CORRUPTED), + ("42P16", SqlState::INVALID_TABLE_DEFINITION), + ("55P02", SqlState::CANT_CHANGE_RUNTIME_PARAM), + ("22019", SqlState::INVALID_ESCAPE_CHARACTER), + ("P0001", SqlState::RAISE_EXCEPTION), + ("72000", SqlState::SNAPSHOT_TOO_OLD), + ("42P11", SqlState::INVALID_CURSOR_DEFINITION), + ("40P01", SqlState::T_R_DEADLOCK_DETECTED), + ("57P02", SqlState::CRASH_SHUTDOWN), + ("HV00A", SqlState::FDW_INVALID_STRING_FORMAT), + ("2F002", 
SqlState::S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), + ("23503", SqlState::FOREIGN_KEY_VIOLATION), + ("40000", SqlState::TRANSACTION_ROLLBACK), + ("22032", SqlState::INVALID_JSON_TEXT), + ("2202E", SqlState::ARRAY_ELEMENT_ERROR), + ("42P19", SqlState::INVALID_RECURSION), + ("42611", SqlState::INVALID_COLUMN_DEFINITION), + ("42P13", SqlState::INVALID_FUNCTION_DEFINITION), + ("25003", SqlState::INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION), + ("39P02", SqlState::E_R_I_E_SRF_PROTOCOL_VIOLATED), + ("XX000", SqlState::INTERNAL_ERROR), + ("08006", SqlState::CONNECTION_FAILURE), + ("57P04", SqlState::DATABASE_DROPPED), + ("42P07", SqlState::DUPLICATE_TABLE), + ("22P03", SqlState::INVALID_BINARY_REPRESENTATION), + ("22035", SqlState::NO_SQL_JSON_ITEM), + ("42P14", SqlState::INVALID_PSTATEMENT_DEFINITION), + ("01007", SqlState::WARNING_PRIVILEGE_NOT_GRANTED), + ("38004", SqlState::E_R_E_READING_SQL_DATA_NOT_PERMITTED), + ("42P21", SqlState::COLLATION_MISMATCH), + ("0Z000", SqlState::DIAGNOSTICS_EXCEPTION), + ("HV001", SqlState::FDW_OUT_OF_MEMORY), + ("0F000", SqlState::LOCATOR_EXCEPTION), + ("22013", SqlState::INVALID_PRECEDING_OR_FOLLOWING_SIZE), + ("2201E", SqlState::INVALID_ARGUMENT_FOR_LOG), + ("22011", SqlState::SUBSTRING_ERROR), + ("42602", SqlState::INVALID_NAME), + ("01004", SqlState::WARNING_STRING_DATA_RIGHT_TRUNCATION), + ("42P02", SqlState::UNDEFINED_PARAMETER), + ("2203C", SqlState::SQL_JSON_OBJECT_NOT_FOUND), + ("HV002", SqlState::FDW_DYNAMIC_PARAMETER_VALUE_NEEDED), + ("0F001", SqlState::L_E_INVALID_SPECIFICATION), + ("58P01", SqlState::UNDEFINED_FILE), + ("38001", SqlState::E_R_E_CONTAINING_SQL_NOT_PERMITTED), + ("42703", SqlState::UNDEFINED_COLUMN), + ("57P05", SqlState::IDLE_SESSION_TIMEOUT), + ("57P03", SqlState::CANNOT_CONNECT_NOW), + ("HV007", SqlState::FDW_INVALID_COLUMN_NAME), + ("22014", SqlState::INVALID_ARGUMENT_FOR_NTILE), + ("22P06", SqlState::NONSTANDARD_USE_OF_ESCAPE_CHARACTER), + ("2203F", SqlState::SQL_JSON_SCALAR_REQUIRED), + ("2200F", 
SqlState::ZERO_LENGTH_CHARACTER_STRING), + ("09000", SqlState::TRIGGERED_ACTION_EXCEPTION), + ("2201F", SqlState::INVALID_ARGUMENT_FOR_POWER_FUNCTION), + ("08003", SqlState::CONNECTION_DOES_NOT_EXIST), + ("38002", SqlState::E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), + ("F0001", SqlState::LOCK_FILE_EXISTS), + ("42P22", SqlState::INDETERMINATE_COLLATION), + ("2200C", SqlState::INVALID_USE_OF_ESCAPE_CHARACTER), + ("2203E", SqlState::TOO_MANY_JSON_OBJECT_MEMBERS), + ("23514", SqlState::CHECK_VIOLATION), + ("22P02", SqlState::INVALID_TEXT_REPRESENTATION), + ("54023", SqlState::TOO_MANY_ARGUMENTS), + ("2200T", SqlState::INVALID_XML_PROCESSING_INSTRUCTION), + ("22016", SqlState::INVALID_ARGUMENT_FOR_NTH_VALUE), + ("25P03", SqlState::IDLE_IN_TRANSACTION_SESSION_TIMEOUT), + ("3B001", SqlState::S_E_INVALID_SPECIFICATION), + ("08001", SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION), + ("22036", SqlState::NON_NUMERIC_SQL_JSON_ITEM), + ("3F000", SqlState::INVALID_SCHEMA_NAME), + ("39P01", SqlState::E_R_I_E_TRIGGER_PROTOCOL_VIOLATED), + ("22026", SqlState::STRING_DATA_LENGTH_MISMATCH), + ("42P17", SqlState::INVALID_OBJECT_DEFINITION), + ("22034", SqlState::MORE_THAN_ONE_SQL_JSON_ITEM), + ("HV000", SqlState::FDW_ERROR), + ("2200B", SqlState::ESCAPE_CHARACTER_CONFLICT), + ("HV008", SqlState::FDW_INVALID_COLUMN_NUMBER), + ("34000", SqlState::INVALID_CURSOR_NAME), + ("2201G", SqlState::INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION), + ("44000", SqlState::WITH_CHECK_OPTION_VIOLATION), + ("HV010", SqlState::FDW_FUNCTION_SEQUENCE_ERROR), + ("39004", SqlState::E_R_I_E_NULL_VALUE_NOT_ALLOWED), + ("22001", SqlState::STRING_DATA_RIGHT_TRUNCATION), + ("3D000", SqlState::INVALID_CATALOG_NAME), + ("25005", SqlState::NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION), + ("2200L", SqlState::NOT_AN_XML_DOCUMENT), + ("27000", SqlState::TRIGGERED_DATA_CHANGE_VIOLATION), + ("HV090", SqlState::FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH), + ("42939", SqlState::RESERVED_NAME), + ("58000", 
SqlState::SYSTEM_ERROR), + ("2200M", SqlState::INVALID_XML_DOCUMENT), + ("HV00L", SqlState::FDW_UNABLE_TO_CREATE_EXECUTION), + ("57014", SqlState::QUERY_CANCELED), + ("23502", SqlState::NOT_NULL_VIOLATION), + ("22002", SqlState::NULL_VALUE_NO_INDICATOR_PARAMETER), + ("HV00R", SqlState::FDW_TABLE_NOT_FOUND), + ("HV00P", SqlState::FDW_NO_SCHEMAS), + ("38003", SqlState::E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), + ("39000", SqlState::EXTERNAL_ROUTINE_INVOCATION_EXCEPTION), + ("22015", SqlState::INTERVAL_FIELD_OVERFLOW), + ("HV00K", SqlState::FDW_REPLY_HANDLE), + ("HV024", SqlState::FDW_INVALID_ATTRIBUTE_VALUE), + ("2200D", SqlState::INVALID_ESCAPE_OCTET), + ("08007", SqlState::TRANSACTION_RESOLUTION_UNKNOWN), + ("2F003", SqlState::S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), + ("42725", SqlState::AMBIGUOUS_FUNCTION), + ("2203A", SqlState::SQL_JSON_MEMBER_NOT_FOUND), + ("42846", SqlState::CANNOT_COERCE), + ("42P04", SqlState::DUPLICATE_DATABASE), + ("42000", SqlState::SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION), + ("2203B", SqlState::SQL_JSON_NUMBER_NOT_FOUND), + ("42P05", SqlState::DUPLICATE_PSTATEMENT), + ("53300", SqlState::TOO_MANY_CONNECTIONS), + ("53400", SqlState::CONFIGURATION_LIMIT_EXCEEDED), + ("42704", SqlState::UNDEFINED_OBJECT), + ("2202G", SqlState::INVALID_TABLESAMPLE_REPEAT), + ("22023", SqlState::INVALID_PARAMETER_VALUE), + ("53000", SqlState::INSUFFICIENT_RESOURCES), + ], +}; diff --git a/crates/corro-types/Cargo.toml b/crates/corro-types/Cargo.toml index cd9bec9a..5c3649cf 100644 --- a/crates/corro-types/Cargo.toml +++ b/crates/corro-types/Cargo.toml @@ -20,12 +20,15 @@ enquote = { workspace = true } fallible-iterator = { workspace = true } foca = { workspace = true } futures = { workspace = true } +hyper = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } +rmp-serde = { workspace = true } metrics = { workspace = true } once_cell = { workspace = true } opentelemetry = { workspace = true } parking_lot = { workspace = true 
} +pin-project-lite = { workspace = true } rand = { workspace = true } rangemap = { workspace = true } rcgen = { workspace = true } diff --git a/crates/corro-types/src/agent.rs b/crates/corro-types/src/agent.rs index 2e7ca705..a339be9c 100644 --- a/crates/corro-types/src/agent.rs +++ b/crates/corro-types/src/agent.rs @@ -20,6 +20,10 @@ use parking_lot::RwLock; use rangemap::RangeInclusiveSet; use rusqlite::{Connection, InterruptHandle}; use serde::{Deserialize, Serialize}; +use tokio::sync::{ + OwnedRwLockWriteGuard as OwnedTokioRwLockWriteGuard, RwLock as TokioRwLock, + RwLockReadGuard as TokioRwLockReadGuard, RwLockWriteGuard as TokioRwLockWriteGuard, +}; use tokio::{ runtime::Handle, sync::{ @@ -27,13 +31,6 @@ use tokio::{ oneshot, Semaphore, }, }; -use tokio::{ - sync::{ - OwnedRwLockWriteGuard as OwnedTokioRwLockWriteGuard, RwLock as TokioRwLock, - RwLockReadGuard as TokioRwLockReadGuard, RwLockWriteGuard as TokioRwLockWriteGuard, - }, - task::block_in_place, -}; use tokio_util::sync::{CancellationToken, DropGuard}; use tracing::{debug, error, info, Instrument}; use tripwire::Tripwire; @@ -363,12 +360,16 @@ impl SplitPool { } #[tracing::instrument(skip(self), level = "debug")] - pub async fn dedicated(&self) -> rusqlite::Result { - block_in_place(|| { - let mut conn = rusqlite::Connection::open(&self.0.path)?; - setup_conn(&mut conn, &self.0.attachments)?; - Ok(conn) - }) + pub fn dedicated(&self) -> rusqlite::Result { + let mut conn = rusqlite::Connection::open(&self.0.path)?; + setup_conn(&mut conn, &self.0.attachments)?; + Ok(conn) + } + + #[tracing::instrument(skip(self), level = "debug")] + pub fn client_dedicated(&self) -> rusqlite::Result { + let conn = rusqlite::Connection::open(&self.0.path)?; + rusqlite_to_crsqlite(conn) } // get a high priority write connection (e.g. 
client input) diff --git a/crates/corro-types/src/config.rs b/crates/corro-types/src/config.rs index ef56afd6..f7586101 100644 --- a/crates/corro-types/src/config.rs +++ b/crates/corro-types/src/config.rs @@ -87,6 +87,14 @@ pub struct ApiConfig { pub bind_addr: SocketAddr, #[serde(alias = "authz", default)] pub authorization: Option, + #[serde(default)] + pub pg: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PgConfig { + #[serde(alias = "addr")] + pub bind_addr: SocketAddr, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -273,6 +281,7 @@ impl ConfigBuilder { api: ApiConfig { bind_addr: self.api_addr.ok_or(ConfigBuilderError::ApiAddrRequired)?, authorization: None, + pg: None, }, gossip: GossipConfig { bind_addr: self diff --git a/crates/corro-types/src/http.rs b/crates/corro-types/src/http.rs new file mode 100644 index 00000000..629ee594 --- /dev/null +++ b/crates/corro-types/src/http.rs @@ -0,0 +1,181 @@ +use std::{ + error::Error, + io, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures::{ready, Stream}; +use hyper::Body; +use pin_project_lite::pin_project; +use tokio_util::codec::{Decoder, Encoder, LinesCodecError}; + +pin_project! 
{ + pub struct IoBodyStream { + #[pin] + pub body: Body + } +} + +impl Stream for IoBodyStream { + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + let res = ready!(this.body.poll_next(cx)); + match res { + Some(Ok(b)) => Poll::Ready(Some(Ok(b))), + Some(Err(e)) => { + let io_err = match e + .source() + .and_then(|source| source.downcast_ref::()) + { + Some(io_err) => io::Error::from(io_err.kind()), + None => io::Error::new(io::ErrorKind::Other, e), + }; + Poll::Ready(Some(Err(io_err))) + } + None => Poll::Ready(None), + } + } +} + +// type IoBodyStreamReader = StreamReader; +// type FramedBody = FramedRead; + +pub struct LinesBytesCodec { + // Stored index of the next index to examine for a `\n` character. + // This is used to optimize searching. + // For example, if `decode` was called with `abc`, it would hold `3`, + // because that is the next index to examine. + // The next time `decode` is called with `abcde\n`, the method will + // only look at `de\n` before returning. + next_index: usize, + + /// The maximum length for a given line. If `usize::MAX`, lines will be + /// read until a `\n` character is reached. + max_length: usize, + + /// Are we currently discarding the remainder of a line which was over + /// the length limit? + is_discarding: bool, +} + +impl Default for LinesBytesCodec { + /// Returns a `LinesBytesCodec` for splitting up data into lines. + /// + /// # Note + /// + /// The returned `LinesBytesCodec` will not have an upper bound on the length + /// of a buffered line. See the documentation for [`new_with_max_length`] + /// for information on why this could be a potential security risk. 
+ /// + /// [`new_with_max_length`]: crate::codec::LinesBytesCodec::default_with_max_length() + fn default() -> Self { + LinesBytesCodec { + next_index: 0, + max_length: usize::MAX, + is_discarding: false, + } + } +} + +impl Decoder for LinesBytesCodec { + type Item = BytesMut; + type Error = LinesCodecError; + + fn decode(&mut self, buf: &mut BytesMut) -> Result, LinesCodecError> { + loop { + // Determine how far into the buffer we'll search for a newline. If + // there's no max_length set, we'll read to the end of the buffer. + let read_to = std::cmp::min(self.max_length.saturating_add(1), buf.len()); + + let newline_offset = buf[self.next_index..read_to] + .iter() + .position(|b| *b == b'\n'); + + match (self.is_discarding, newline_offset) { + (true, Some(offset)) => { + // If we found a newline, discard up to that offset and + // then stop discarding. On the next iteration, we'll try + // to read a line normally. + buf.advance(offset + self.next_index + 1); + self.is_discarding = false; + self.next_index = 0; + } + (true, None) => { + // Otherwise, we didn't find a newline, so we'll discard + // everything we read. On the next iteration, we'll continue + // discarding up to max_len bytes unless we find a newline. + buf.advance(read_to); + self.next_index = 0; + if buf.is_empty() { + return Ok(None); + } + } + (false, Some(offset)) => { + // Found a line! + let newline_index = offset + self.next_index; + self.next_index = 0; + let mut line = buf.split_to(newline_index + 1); + line.truncate(line.len() - 1); + without_carriage_return(&mut line); + return Ok(Some(line)); + } + (false, None) if buf.len() > self.max_length => { + // Reached the maximum length without finding a + // newline, return an error and start discarding on the + // next call. 
+ self.is_discarding = true; + return Err(LinesCodecError::MaxLineLengthExceeded); + } + (false, None) => { + // We didn't find a line or reach the length limit, so the next + // call will resume searching at the current offset. + self.next_index = read_to; + return Ok(None); + } + } + } + } + + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, LinesCodecError> { + Ok(match self.decode(buf)? { + Some(frame) => Some(frame), + None => { + // No terminating newline - return remaining data, if any + if buf.is_empty() || buf == &b"\r"[..] { + None + } else { + let mut line = buf.split_to(buf.len()); + line.truncate(line.len() - 1); + without_carriage_return(&mut line); + self.next_index = 0; + Some(line) + } + } + }) + } +} + +fn without_carriage_return(s: &mut BytesMut) { + if let Some(&b'\r') = s.last() { + s.truncate(s.len() - 1); + } +} + +impl Encoder for LinesBytesCodec +where + T: AsRef<[u8]>, +{ + type Error = LinesCodecError; + + fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), LinesCodecError> { + let line = line.as_ref(); + buf.reserve(line.len() + 1); + buf.put(line); + buf.put_u8(b'\n'); + Ok(()) + } +} diff --git a/crates/corro-types/src/lib.rs b/crates/corro-types/src/lib.rs index 060fce58..778962fd 100644 --- a/crates/corro-types/src/lib.rs +++ b/crates/corro-types/src/lib.rs @@ -5,6 +5,7 @@ pub mod api; pub mod broadcast; pub mod change; pub mod config; +pub mod http; pub mod members; pub mod pubsub; pub mod schema; diff --git a/crates/corro-types/src/schema.rs b/crates/corro-types/src/schema.rs index 5b4745ba..b36f3d6d 100644 --- a/crates/corro-types/src/schema.rs +++ b/crates/corro-types/src/schema.rs @@ -677,14 +677,12 @@ fn prepare_table( .unwrap_or_else(|| { Ok(columns .iter() - .filter_map(|def| { - def.constraints - .iter() - .any(|named| { - matches!(named.constraint, ColumnConstraint::PrimaryKey { .. 
}) - }) - .then(|| def.col_name.0.clone()) + .filter(|&def| { + def.constraints.iter().any(|named| { + matches!(named.constraint, ColumnConstraint::PrimaryKey { .. }) + }) }) + .map(|def| def.col_name.0.clone()) .collect()) })?; From a9addcfbc5ae7d642139985acd0a3631c7585174 Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Mon, 23 Oct 2023 12:15:38 -0400 Subject: [PATCH 02/12] improvements --- Cargo.lock | 1 + crates/corro-pg/Cargo.toml | 3 +- crates/corro-pg/src/lib.rs | 713 ++++++++++++++++++++++++++++++------- 3 files changed, 589 insertions(+), 128 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e13e2df7..c70c59f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -843,6 +843,7 @@ dependencies = [ "tokio-util", "tracing", "tracing-subscriber", + "tripwire", ] [[package]] diff --git a/crates/corro-pg/Cargo.toml b/crates/corro-pg/Cargo.toml index 528be5f8..36e5f878 100644 --- a/crates/corro-pg/Cargo.toml +++ b/crates/corro-pg/Cargo.toml @@ -23,4 +23,5 @@ postgres-types = { version = "0.2", features = ["with-time-0_3"] } tracing-subscriber = { workspace = true } tempfile = { workspace = true } corro-tests = { path = "../corro-tests" } -tokio-postgres = { version = "0.7.10" } \ No newline at end of file +tokio-postgres = { version = "0.7.10" } +tripwire = { path = "../tripwire" } \ No newline at end of file diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs index b373c451..1d0385eb 100644 --- a/crates/corro-pg/src/lib.rs +++ b/crates/corro-pg/src/lib.rs @@ -2,24 +2,30 @@ pub mod proto; pub mod proto_ext; pub mod sql_state; -use std::{collections::HashMap, net::SocketAddr}; +use std::{collections::HashMap, net::SocketAddr, sync::Arc}; use compact_str::CompactString; use corro_types::{agent::Agent, change::SqliteValue, config::PgConfig}; use futures::{Sink, SinkExt, StreamExt}; use pgwire::{ - api::ClientInfoHolder, - error::{ErrorInfo, PgWireError}, + api::{ + results::{DataRowEncoder, FieldFormat, FieldInfo, Tag}, + ClientInfo, 
ClientInfoHolder, + }, + error::{ErrorInfo, PgWireError, PgWireResult}, messages::{ + data::{FieldDescription, ParameterDescription, RowDescription}, + extendedquery::{CloseComplete, ParseComplete}, response::{ReadyForQuery, READY_STATUS_IDLE}, PgWireBackendMessage, PgWireFrontendMessage, }, tokio::PgWireMessageServerCodec, }; -use rusqlite::Statement; -use tokio::net::TcpListener; +use postgres_types::{FromSql, Oid, Type}; +use rusqlite::{types::ValueRef, Statement}; +use tokio::{net::TcpListener, sync::mpsc::channel, task::block_in_place}; use tokio_util::codec::{Framed, FramedRead}; -use tracing::{debug, info}; +use tracing::{debug, info, warn}; use crate::{ proto::{Bind, ConnectionCodec, ProtocolError}, @@ -49,10 +55,12 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { ); let msg = framed.next().await.unwrap()?; + println!("msg: {msg:?}"); match msg { PgWireFrontendMessage::Startup(startup) => { info!("received startup message: {startup:?}"); + println!("huh..."); } _ => { framed @@ -69,128 +77,576 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { } } + framed.set_state(pgwire::api::PgWireConnectionState::ReadyForQuery); + framed - .send(PgWireBackendMessage::Authentication( + .feed(PgWireBackendMessage::Authentication( pgwire::messages::startup::Authentication::Ok, )) .await?; framed - .send(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new( + .feed(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new( READY_STATUS_IDLE, ))) .await?; - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - - std::thread::spawn(move || -> Result<(), BoxError> { - rt.block_on(async move { - let conn = rusqlite::Connection::open_in_memory().unwrap(); + framed.flush().await?; - let mut prepared: HashMap)> = HashMap::new(); + println!("sent auth ok and ReadyForQuery"); - let mut portals: HashMap = - HashMap::new(); + let (front_tx, mut front_rx) = channel(1024); + let (back_tx, mut back_rx) = channel(1024); - let mut 
row_cache: Vec = vec![]; + let (mut sink, mut stream) = framed.split(); - while let Some(decode_res) = framed.next().await { + tokio::spawn({ + let back_tx = back_tx.clone(); + async move { + while let Some(decode_res) = stream.next().await { + println!("decode_res: {decode_res:?}"); let msg = match decode_res { Ok(msg) => msg, Err(PgWireError::IoError(io_error)) => { debug!("postgres io error: {io_error}"); break; } - // Err(ProtocolError::ParserError) => { - // framed - // .send(proto::ErrorResponse::new( - // proto::SqlState::SyntaxError, - // proto::Severity::Error, - // "parsing error", - // )) - // .await?; - // continue; - // } Err(e) => { - framed - .send(PgWireBackendMessage::ErrorResponse( + // attempt to send this... + _ = back_tx.try_send(( + PgWireBackendMessage::ErrorResponse( ErrorInfo::new( "FATAL".to_owned(), "XX000".to_owned(), e.to_string(), ) .into(), - )) - .await?; + ), + true, + )); break; } }; - match msg { - PgWireFrontendMessage::Startup(_) => { - framed - .send(PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "FATAL".into(), - SqlState::PROTOCOL_VIOLATION.code().into(), - "unexpected startup message".into(), - ) - .into(), - )) - .await?; - continue; + front_tx.send(msg).await?; + } + + Ok::<_, BoxError>(()) + } + }); + + tokio::spawn(async move { + while let Some((back, flush)) = back_rx.recv().await { + println!("sending: {back:?}"); + sink.feed(back).await?; + if flush { + sink.flush().await?; + } + } + Ok::<_, std::io::Error>(()) + }); + + block_in_place(|| { + let conn = rusqlite::Connection::open_in_memory().unwrap(); + println!("opened in-memory conn"); + + let mut prepared: HashMap)> = + HashMap::new(); + + let mut portals: HashMap = + HashMap::new(); + + let mut row_cache: Vec = vec![]; + + 'outer: while let Some(msg) = front_rx.blocking_recv() { + println!("msg: {msg:?}"); + + match msg { + PgWireFrontendMessage::Startup(_) => { + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + 
"FATAL".into(), + SqlState::PROTOCOL_VIOLATION.code().into(), + "unexpected startup message".into(), + ) + .into(), + ), + true, + ))?; + continue; + } + PgWireFrontendMessage::Parse(parse) => { + let prepped = match conn.prepare(parse.query()) { + Ok(prepped) => prepped, + Err(e) => { + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ))?; + continue; + } + }; + + prepared.insert( + parse.name().as_deref().unwrap_or("").into(), + ( + parse.query().clone(), + prepped, + parse + .type_oids() + .iter() + .filter_map(|oid| Type::from_oid(*oid)) + .collect(), + ), + ); + + back_tx.blocking_send(( + PgWireBackendMessage::ParseComplete(ParseComplete::new()), + true, + ))?; + } + PgWireFrontendMessage::Describe(desc) => { + let name = desc.name().as_deref().unwrap_or(""); + match desc.target_type() { + // statement + b'S' => { + if let Some((_, prepped, _)) = prepared.get(name) { + let mut oids = vec![]; + let mut fields = vec![]; + for col in prepped.columns() { + let col_type = match name_to_type( + col.decl_type().unwrap_or("text"), + ) { + Ok(t) => t, + Err(e) => { + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse( + e.into(), + ), + true, + ))?; + continue 'outer; + } + }; + oids.push(col_type.oid()); + fields.push(FieldInfo::new( + col.name().to_string(), + None, + None, + col_type, + FieldFormat::Text, + )); + } + back_tx.blocking_send(( + PgWireBackendMessage::ParameterDescription( + ParameterDescription::new(oids), + ), + false, + ))?; + back_tx.blocking_send(( + PgWireBackendMessage::RowDescription( + RowDescription::new( + fields.iter().map(Into::into).collect(), + ), + ), + true, + ))?; + continue; + } + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".into(), + "statement not found".into(), + ) + .into(), + ), + true, + ))?; + } + // portal + b'P' => { + if let 
Some((_, prepped)) = portals.get(name) { + let mut oids = vec![]; + let mut fields = vec![]; + for col in prepped.columns() { + let col_type = match name_to_type( + col.decl_type().unwrap_or("text"), + ) { + Ok(t) => t, + Err(e) => { + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse( + e.into(), + ), + true, + ))?; + continue 'outer; + } + }; + oids.push(col_type.oid()); + fields.push(FieldInfo::new( + col.name().to_string(), + None, + None, + col_type, + FieldFormat::Text, + )); + } + back_tx.blocking_send(( + PgWireBackendMessage::RowDescription( + RowDescription::new( + fields.iter().map(Into::into).collect(), + ), + ), + true, + ))?; + continue; + } + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".into(), + "portal not found".into(), + ) + .into(), + ), + true, + ))?; + } + _ => { + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".into(), + SqlState::PROTOCOL_VIOLATION.code().into(), + "unexpected describe type".into(), + ) + .into(), + ), + true, + ))?; + continue; + } } - PgWireFrontendMessage::Parse(parse) => { - if let Err(e) = conn.prepare_cached(parse.query()) { - framed - .send(PgWireBackendMessage::ErrorResponse( + } + PgWireFrontendMessage::Bind(bind) => { + let portal_name = bind + .portal_name() + .as_deref() + .map(CompactString::from) + .unwrap_or_default(); + + let stmt_name = bind.statement_name().as_deref().unwrap_or(""); + + let (sql, _, param_types) = match prepared.get(stmt_name) { + None => { + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + "statement not found".into(), + ) + .into(), + ), + true, + ))?; + continue; + } + Some(stmt) => stmt, + }; + + let mut prepped = match conn.prepare(sql) { + Ok(prepped) => prepped, + Err(e) => { + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse( ErrorInfo::new( "ERROR".to_owned(), 
"XX000".to_owned(), e.to_string(), ) .into(), - )) - .await?; + ), + true, + ))?; continue; } + }; + + for (i, param) in bind.parameters().iter().enumerate() { + let idx = i + 1; + let b = match param { + None => { + if let Err(e) = prepped + .raw_bind_parameter(idx, rusqlite::types::Null) + { + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ))?; + continue 'outer; + } + continue; + } + Some(b) => b, + }; + + match param_types.get(i) { + None => { + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + "missing parameter type".into(), + ) + .into(), + ), + true, + ))?; + continue 'outer; + } + Some(param_type) => match param_type { + &Type::BOOL => { + let value: bool = + FromSql::from_sql(param_type, b.as_ref())?; + prepped.raw_bind_parameter(idx, value)?; + } + &Type::INT2 => { + let value: i16 = + FromSql::from_sql(param_type, b.as_ref())?; + prepped.raw_bind_parameter(idx, value)?; + } + &Type::INT4 => { + let value: i32 = + FromSql::from_sql(param_type, b.as_ref())?; + prepped.raw_bind_parameter(idx, value)?; + } + &Type::INT8 => { + let value: i64 = + FromSql::from_sql(param_type, b.as_ref())?; + prepped.raw_bind_parameter(idx, value)?; + } + &Type::TEXT | &Type::VARCHAR => { + let value: &str = + FromSql::from_sql(param_type, b.as_ref())?; + prepped.raw_bind_parameter(idx, value)?; + } + &Type::FLOAT4 => { + let value: f32 = + FromSql::from_sql(param_type, b.as_ref())?; + prepped.raw_bind_parameter(idx, value)?; + } + &Type::FLOAT8 => { + let value: f64 = + FromSql::from_sql(param_type, b.as_ref())?; + prepped.raw_bind_parameter(idx, value)?; + } + &Type::BYTEA => { + let value: &[u8] = + FromSql::from_sql(param_type, b.as_ref())?; + prepped.raw_bind_parameter(idx, value)?; + } + t => { + warn!("unsupported type: {t:?}"); + back_tx.blocking_send(( + 
PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + format!( + "unsupported type {t} at index {i}" + ), + ) + .into(), + ), + true, + ))?; + continue 'outer; + } + }, + } + } - prepared.insert( - parse.name().as_deref().unwrap_or("").into(), - (parse.query().clone(), parse.type_oids().clone()), - ); + portals.insert(portal_name, (stmt_name.into(), prepped)); + } + PgWireFrontendMessage::Sync(_) => { + back_tx.blocking_send(( + PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new( + READY_STATUS_IDLE, + )), + true, + ))?; + } + PgWireFrontendMessage::Execute(_) => todo!(), + PgWireFrontendMessage::Query(query) => { + let mut prepped = match conn.prepare(query.query()) { + Ok(prepped) => prepped, + Err(e) => { + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ))?; + continue; + } + }; + + let mut fields = vec![]; + for col in prepped.columns() { + let col_type = + match name_to_type(col.decl_type().unwrap_or("text")) { + Ok(t) => t, + Err(e) => { + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse(e.into()), + true, + ))?; + continue 'outer; + } + }; + fields.push(FieldInfo::new( + col.name().to_string(), + None, + None, + col_type, + FieldFormat::Text, + )); + } + + back_tx.blocking_send(( + PgWireBackendMessage::RowDescription(RowDescription::new( + fields.iter().map(Into::into).collect(), + )), + true, + ))?; + + let schema = Arc::new(fields); + + let mut rows = prepped.raw_query(); + let ncols = schema.len(); + + let mut count = 0; + while let Ok(Some(row)) = rows.next() { + count += 1; + let mut encoder = DataRowEncoder::new(schema.clone()); + for idx in 0..ncols { + let data = row.get_ref_unwrap::(idx); + match data { + ValueRef::Null => { + encoder.encode_field(&None::).unwrap() + } + ValueRef::Integer(i) => { + encoder.encode_field(&i).unwrap(); + } + ValueRef::Real(f) => { + 
encoder.encode_field(&f).unwrap(); + } + ValueRef::Text(t) => { + encoder + .encode_field( + &String::from_utf8_lossy(t).as_ref(), + ) + .unwrap(); + } + ValueRef::Blob(b) => { + encoder.encode_field(&b).unwrap(); + } + } + } } - PgWireFrontendMessage::Describe(_) => todo!(), - PgWireFrontendMessage::Bind(bind) => { - let portal_name = bind - .portal_name() - .as_deref() - .map(CompactString::from) - .unwrap_or_default(); - - let stmt_name = bind.statement_name().as_deref().unwrap_or(""); + + // TODO: figure out what kind of execution it is: SELECT, INSERT, etc. + back_tx.blocking_send(( + PgWireBackendMessage::CommandComplete( + Tag::new_for_query(count).into(), + ), + true, + ))?; + } + PgWireFrontendMessage::Terminate(_) => { + break; + } + + PgWireFrontendMessage::PasswordMessageFamily(_) => todo!(), + PgWireFrontendMessage::Close(close) => { + let name = close.name().as_deref().unwrap_or(""); + match close.target_type() { + // statement + b'S' => { + if let Some((_, prepped, _)) = prepared.remove(name) { + portals.retain(|_, (stmt_name, _)| { + stmt_name.as_str() != name + }); + back_tx.blocking_send(( + PgWireBackendMessage::CloseComplete( + CloseComplete::new(), + ), + true, + ))?; + continue; + } + // not finding a statement is not an error + } + // portal + b'P' => { + portals.remove(name); + back_tx.blocking_send(( + PgWireBackendMessage::CloseComplete( + CloseComplete::new(), + ), + true, + ))?; + } + _ => { + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".into(), + SqlState::PROTOCOL_VIOLATION.code().into(), + "unexpected Close target_type".into(), + ) + .into(), + ), + true, + ))?; + continue; + } } - PgWireFrontendMessage::Sync(_) => todo!(), - PgWireFrontendMessage::Execute(_) => todo!(), - PgWireFrontendMessage::Query(_) => todo!(), - PgWireFrontendMessage::Terminate(_) => todo!(), - - PgWireFrontendMessage::PasswordMessageFamily(_) => todo!(), - PgWireFrontendMessage::Close(_) => todo!(), - 
PgWireFrontendMessage::Flush(_) => todo!(), - PgWireFrontendMessage::CopyData(_) => todo!(), - PgWireFrontendMessage::CopyFail(_) => todo!(), - PgWireFrontendMessage::CopyDone(_) => todo!(), } + PgWireFrontendMessage::Flush(_) => todo!(), + PgWireFrontendMessage::CopyData(_) => todo!(), + PgWireFrontendMessage::CopyFail(_) => todo!(), + PgWireFrontendMessage::CopyDone(_) => todo!(), } + } - Ok::<_, BoxError>(()) - }) - }) - .join() - .unwrap()?; + Ok::<_, BoxError>(()) + })?; Ok::<_, BoxError>(()) }); @@ -204,57 +660,60 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { return Ok(PgServer { local_addr }); } -// #[cfg(test)] -// mod tests { -// use tokio_postgres::NoTls; - -// use super::*; - -// #[tokio::test] -// async fn test_pg() -> Result<(), BoxError> { -// _ = tracing_subscriber::fmt::try_init(); -// let server = TcpListener::bind("127.0.0.1:0").await?; -// let local_addr = server.local_addr()?; - -// let conn_str = format!( -// "host={} port={} user=testuser", -// local_addr.ip(), -// local_addr.port() -// ); - -// let client_task = tokio::spawn(async move { -// let (client, client_conn) = tokio_postgres::connect(&conn_str, NoTls).await?; -// println!("client is ready!"); -// tokio::spawn(client_conn); +fn name_to_type(name: &str) -> Result { + match name.to_uppercase().as_ref() { + "INT" => Ok(Type::INT8), + "VARCHAR" => Ok(Type::VARCHAR), + "TEXT" => Ok(Type::TEXT), + "BINARY" => Ok(Type::BYTEA), + "FLOAT" => Ok(Type::FLOAT8), + _ => Err(ErrorInfo::new( + "ERROR".to_owned(), + "42846".to_owned(), + format!("Unsupported data type: {name}"), + )), + } +} -// client.prepare("SELECT 1").await?; -// Ok::<_, BoxError>(()) -// }); +#[cfg(test)] +mod tests { + use corro_tests::launch_test_agent; + use tokio_postgres::NoTls; + use tripwire::Tripwire; -// let (conn, remote_addr) = server.accept().await?; -// println!("accepted a conn, addr: {remote_addr}"); + use super::*; -// let mut framed = Framed::new(conn, ConnectionCodec::new()); + #[tokio::test(flavor 
= "multi_thread")] + async fn test_pg() -> Result<(), BoxError> { + _ = tracing_subscriber::fmt::try_init(); + let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple(); -// let msg = framed.next().await.unwrap()?; -// println!("recv msg: {msg:?}"); + let ta = launch_test_agent(|builder| builder.build(), tripwire).await?; -// framed.send(proto::AuthenticationOk).await?; -// framed.send(proto::ReadyForQuery).await?; + let server = start( + ta.agent.clone(), + PgConfig { + bind_addr: "127.0.0.1:0".parse()?, + }, + ) + .await?; -// let msg = framed.next().await.unwrap()?; -// println!("recv msg: {msg:?}"); + let conn_str = format!( + "host={} port={} user=testuser", + server.local_addr.ip(), + server.local_addr.port() + ); -// let query = if let PgWireFrontendMessage::Parse(Parse { query, .. }) = msg { -// query -// } else { -// panic!("unexpected message"); -// }; + let (client, client_conn) = tokio_postgres::connect(&conn_str, NoTls).await?; + println!("client is ready!"); + tokio::spawn(client_conn); -// println!("query: {query}"); + let stmt = client.prepare("SELECT 1").await?; + println!("after prepare"); -// assert!(client_task.await?.is_ok()); + let rows = client.query(&stmt, &[]).await?; + println!("rows: {rows:?}"); -// Ok(()) -// } -// } + Ok(()) + } +} From 5f26baaf0e74891a32a9b36a11bd8cda910dc093 Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Wed, 25 Oct 2023 16:02:06 -0400 Subject: [PATCH 03/12] much progress, this is a little insane --- Cargo.lock | 3 + crates/corro-agent/Cargo.toml | 1 + crates/corro-agent/src/agent.rs | 13 +- crates/corro-agent/src/api/public/mod.rs | 3 +- crates/corro-pg/Cargo.toml | 2 + crates/corro-pg/src/lib.rs | 1971 +++++++++++++++++++--- crates/corro-types/src/agent.rs | 8 +- crates/corro-types/src/pubsub.rs | 25 +- crates/corro-types/src/schema.rs | 311 ++-- 9 files changed, 1907 insertions(+), 430 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 50beecf3..af7ef499 100644 --- a/Cargo.lock +++ 
b/Cargo.lock @@ -739,6 +739,7 @@ dependencies = [ "camino", "compact_str 0.7.0", "config", + "corro-pg", "corro-speedy", "corro-tests", "corro-types", @@ -830,11 +831,13 @@ dependencies = [ "compact_str 0.7.0", "corro-tests", "corro-types", + "fallible-iterator", "futures", "pgwire", "phf", "postgres-types", "rusqlite", + "sqlite3-parser", "sqlparser", "tempfile", "thiserror", diff --git a/crates/corro-agent/Cargo.toml b/crates/corro-agent/Cargo.toml index c4c4db20..76479149 100644 --- a/crates/corro-agent/Cargo.toml +++ b/crates/corro-agent/Cargo.toml @@ -53,6 +53,7 @@ tripwire = { path = "../tripwire" } trust-dns-resolver = { workspace = true } uhlc = { workspace = true } uuid = { workspace = true } +corro-pg = { path = "../corro-pg" } [dev-dependencies] corro-tests = { path = "../corro-tests" } diff --git a/crates/corro-agent/src/agent.rs b/crates/corro-agent/src/agent.rs index 3c01b01b..46861c5c 100644 --- a/crates/corro-agent/src/agent.rs +++ b/crates/corro-agent/src/agent.rs @@ -128,7 +128,9 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age let schema = { let mut conn = pool.write_priority().await?; migrate(&mut conn)?; - init_schema(&conn)? 
+ let mut schema = init_schema(&conn)?; + schema.constrain()?; + schema }; { @@ -317,6 +319,15 @@ pub async fn run(agent: Agent, opts: AgentOptions) -> eyre::Result<()> { rtt_rx, } = opts; + if let Some(pg_conf) = agent.config().api.pg.clone() { + info!("Starting PostgreSQL wire-compatible server"); + let pg_server = corro_pg::start(agent.clone(), pg_conf).await?; + info!( + "Started PostgreSQL wire-compatible server, listening at {}", + pg_server.local_addr + ); + } + let mut matcher_id_cache = MatcherIdCache::default(); let mut matcher_bcast_cache = MatcherBroadcastCache::default(); diff --git a/crates/corro-agent/src/api/public/mod.rs b/crates/corro-agent/src/api/public/mod.rs index c8eafca6..9160c7c3 100644 --- a/crates/corro-agent/src/api/public/mod.rs +++ b/crates/corro-agent/src/api/public/mod.rs @@ -1139,7 +1139,8 @@ pub async fn api_v1_queries( async fn execute_schema(agent: &Agent, statements: Vec) -> eyre::Result<()> { let new_sql: String = statements.join(";"); - let partial_schema = parse_sql(&new_sql)?; + let mut partial_schema = parse_sql(&new_sql)?; + partial_schema.constrain()?; let mut conn = agent.pool().write_priority().await?; diff --git a/crates/corro-pg/Cargo.toml b/crates/corro-pg/Cargo.toml index 36e5f878..4b1253a7 100644 --- a/crates/corro-pg/Cargo.toml +++ b/crates/corro-pg/Cargo.toml @@ -18,6 +18,8 @@ tracing = { workspace = true } time = { workspace = true } phf = "*" postgres-types = { version = "0.2", features = ["with-time-0_3"] } +sqlite3-parser = { workspace = true } +fallible-iterator = { workspace = true } [dev-dependencies] tracing-subscriber = { workspace = true } diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs index 1d0385eb..e481f58e 100644 --- a/crates/corro-pg/src/lib.rs +++ b/crates/corro-pg/src/lib.rs @@ -2,53 +2,234 @@ pub mod proto; pub mod proto_ext; pub mod sql_state; -use std::{collections::HashMap, net::SocketAddr, sync::Arc}; +use std::{borrow::Borrow, collections::HashMap, future::poll_fn, 
net::SocketAddr, sync::Arc}; +use bytes::{Buf, BytesMut}; use compact_str::CompactString; -use corro_types::{agent::Agent, change::SqliteValue, config::PgConfig}; -use futures::{Sink, SinkExt, StreamExt}; +use corro_types::{ + agent::{Agent, KnownDbVersion}, + broadcast::Timestamp, + config::PgConfig, + schema::{parse_sql, Schema, SchemaError, SqliteType, Table}, +}; +use fallible_iterator::FallibleIterator; +use futures::{SinkExt, StreamExt}; use pgwire::{ api::{ results::{DataRowEncoder, FieldFormat, FieldInfo, Tag}, ClientInfo, ClientInfoHolder, }, - error::{ErrorInfo, PgWireError, PgWireResult}, + error::{ErrorInfo, PgWireError}, messages::{ - data::{FieldDescription, ParameterDescription, RowDescription}, - extendedquery::{CloseComplete, ParseComplete}, - response::{ReadyForQuery, READY_STATUS_IDLE}, + data::{ParameterDescription, RowDescription}, + extendedquery::{BindComplete, CloseComplete, ParseComplete, PortalSuspended}, + response::{ + EmptyQueryResponse, ReadyForQuery, READY_STATUS_IDLE, READY_STATUS_TRANSACTION_BLOCK, + }, + startup::SslRequest, PgWireBackendMessage, PgWireFrontendMessage, }, tokio::PgWireMessageServerCodec, }; -use postgres_types::{FromSql, Oid, Type}; -use rusqlite::{types::ValueRef, Statement}; -use tokio::{net::TcpListener, sync::mpsc::channel, task::block_in_place}; -use tokio_util::codec::{Framed, FramedRead}; -use tracing::{debug, info, warn}; - -use crate::{ - proto::{Bind, ConnectionCodec, ProtocolError}, - sql_state::SqlState, +use postgres_types::{FromSql, Type}; +use rusqlite::{named_params, types::ValueRef, Connection, Statement}; +use sqlite3_parser::ast::{ + Cmd, CreateTableBody, Expr, Id, InsertBody, Literal, Name, OneSelect, ResultColumn, Select, + SelectTable, Stmt, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt, ReadBuf}, + net::{TcpListener, TcpStream}, + sync::mpsc::channel, + task::block_in_place, }; +use tokio_util::codec::Framed; +use tracing::{debug, error, info, trace, warn}; + +use 
crate::sql_state::SqlState; type BoxError = Box; -struct PgServer { - local_addr: SocketAddr, +pub struct PgServer { + pub local_addr: SocketAddr, +} + +enum BackendResponse { + Message { + message: PgWireBackendMessage, + flush: bool, + }, + Flush, +} + +impl From<(PgWireBackendMessage, bool)> for BackendResponse { + fn from((message, flush): (PgWireBackendMessage, bool)) -> Self { + Self::Message { message, flush } + } +} + +#[derive(Clone, Debug)] +struct ParsedCmd(Cmd); + +impl ParsedCmd { + pub fn returns_rows_affected(&self) -> bool { + matches!( + self.0, + Cmd::Stmt(Stmt::Insert { .. }) + | Cmd::Stmt(Stmt::Update { .. }) + | Cmd::Stmt(Stmt::Delete { .. }) + ) + } + pub fn returns_num_rows(&self) -> bool { + matches!( + self.0, + Cmd::Stmt(Stmt::Select(_)) + | Cmd::Stmt(Stmt::CreateTable { + body: CreateTableBody::AsSelect(_), + .. + }) + ) + } + pub fn is_begin(&self) -> bool { + matches!(self.0, Cmd::Stmt(Stmt::Begin(_, _))) + } + pub fn is_commit(&self) -> bool { + matches!(self.0, Cmd::Stmt(Stmt::Commit(_))) + } + pub fn is_rollback(&self) -> bool { + matches!(self.0, Cmd::Stmt(Stmt::Rollback { .. })) + } + + fn tag(&self, rows: Option) -> Tag { + match &self.0 { + Cmd::Stmt(stmt) => match stmt { + Stmt::Select(_) + | Stmt::CreateTable { + body: CreateTableBody::AsSelect(_), + .. + } => Tag::new_for_query(rows.unwrap_or_default()), + Stmt::AlterTable(_, _) => Tag::new_for_execution("ALTER", rows), + Stmt::Analyze(_) => Tag::new_for_execution("ANALYZE", rows), + Stmt::Attach { .. } => Tag::new_for_execution("ATTACH", rows), + Stmt::Begin(_, _) => Tag::new_for_execution("BEGIN", rows), + Stmt::Commit(_) => Tag::new_for_execution("COMMIT", rows), + Stmt::CreateIndex { .. } + | Stmt::CreateTable { .. } + | Stmt::CreateTrigger { .. } + | Stmt::CreateView { .. } + | Stmt::CreateVirtualTable { .. } => Tag::new_for_execution("CREATE", rows), + Stmt::Delete { .. 
} => Tag::new_for_execution("DELETE", rows), + Stmt::Detach(_) => Tag::new_for_execution("DETACH", rows), + Stmt::DropIndex { .. } + | Stmt::DropTable { .. } + | Stmt::DropTrigger { .. } + | Stmt::DropView { .. } => Tag::new_for_execution("DROP", rows), + Stmt::Insert { .. } => Tag::new_for_execution("INSERT", rows), + Stmt::Pragma(_, _) => Tag::new_for_execution("PRAGMA", rows), + Stmt::Reindex { .. } => Tag::new_for_execution("REINDEX", rows), + Stmt::Release(_) => Tag::new_for_execution("RELEAE", rows), + Stmt::Rollback { .. } => Tag::new_for_execution("ROLLBACK", rows), + Stmt::Savepoint(_) => Tag::new_for_execution("SAVEPOINT", rows), + + Stmt::Update { .. } => Tag::new_for_execution("UPDATE", rows), + Stmt::Vacuum(_, _) => Tag::new_for_execution("VACUUM", rows), + }, + _ => Tag::new_for_execution("OK", rows), + } + } +} + +#[derive(Clone, Debug)] +struct Query { + cmds: Vec, +} + +impl Query { + fn len(&self) -> usize { + self.cmds.len() + } + + fn is_empty(&self) -> bool { + self.cmds.is_empty() + } +} + +fn parse_query(sql: &str) -> Result { + let mut cmds = vec![]; + + let mut parser = sqlite3_parser::lexer::sql::Parser::new(sql.as_bytes()); + loop { + match parser.next() { + Ok(Some(cmd)) => { + cmds.push(ParsedCmd(cmd)); + } + Ok(None) => { + break; + } + Err(e) => return Err(e), + } + } + + Ok(Query { cmds }) +} + +enum OpenTx { + Implicit, + Explicit, +} + +async fn peek_for_sslrequest( + tcp_socket: &mut TcpStream, + ssl_supported: bool, +) -> std::io::Result { + let mut ssl = false; + let mut buf = [0u8; SslRequest::BODY_SIZE]; + let mut buf = ReadBuf::new(&mut buf); + loop { + let size = poll_fn(|cx| tcp_socket.poll_peek(cx, &mut buf)).await?; + if size == 0 { + // the tcp_stream has ended + return Ok(false); + } + if size == SslRequest::BODY_SIZE { + let mut buf_ref = buf.filled(); + // skip first 4 bytes + buf_ref.get_i32(); + if buf_ref.get_i32() == SslRequest::BODY_MAGIC_NUMBER { + // the socket is sending sslrequest, read the first 8 bytes + // 
skip first 8 bytes + tcp_socket + .read_exact(&mut [0u8; SslRequest::BODY_SIZE]) + .await?; + // ssl configured + if ssl_supported { + ssl = true; + tcp_socket.write_all(b"S").await?; + } else { + tcp_socket.write_all(b"N").await?; + } + } + + return Ok(ssl); + } + } } -async fn start(agent: Agent, pg: PgConfig) -> Result { - let mut server = TcpListener::bind(pg.bind_addr).await?; +pub async fn start(agent: Agent, pg: PgConfig) -> std::io::Result { + let server = TcpListener::bind(pg.bind_addr).await?; let local_addr = server.local_addr()?; tokio::spawn(async move { loop { - let (conn, remote_addr) = server.accept().await?; + let (mut conn, remote_addr) = server.accept().await?; info!("accepted a conn, addr: {remote_addr}"); let agent = agent.clone(); tokio::spawn(async move { + conn.set_nodelay(true)?; + let ssl = peek_for_sslrequest(&mut conn, false).await?; + println!("SSL? {ssl}"); + let mut framed = Framed::new( conn, PgWireMessageServerCodec::new(ClientInfoHolder::new(remote_addr, false)), @@ -112,17 +293,20 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { } Err(e) => { // attempt to send this... 
- _ = back_tx.try_send(( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "FATAL".to_owned(), - "XX000".to_owned(), - e.to_string(), - ) + _ = back_tx.try_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) .into(), - ), - true, - )); + ); break; } }; @@ -135,61 +319,154 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { }); tokio::spawn(async move { - while let Some((back, flush)) = back_rx.recv().await { - println!("sending: {back:?}"); - sink.feed(back).await?; - if flush { - sink.flush().await?; + while let Some(back) = back_rx.recv().await { + match back { + BackendResponse::Message { message, flush } => { + println!("sending: {message:?}"); + sink.feed(message).await?; + if flush { + sink.flush().await?; + } + } + BackendResponse::Flush => { + sink.flush().await?; + } } } Ok::<_, std::io::Error>(()) }); block_in_place(|| { - let conn = rusqlite::Connection::open_in_memory().unwrap(); - println!("opened in-memory conn"); + let conn = agent.pool().client_dedicated().unwrap(); + println!("opened connection"); + + let schema = match compute_schema(&conn) { + Ok(schema) => schema, + Err(e) => { + error!("could not parse schema: {e}"); + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".into(), + "XX000".into(), + "could not parse database schema".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + return Ok(()); + } + }; - let mut prepared: HashMap)> = - HashMap::new(); + let mut prepared: HashMap< + CompactString, + (String, Option, Statement, Vec), + > = HashMap::new(); - let mut portals: HashMap = - HashMap::new(); + let mut portals: HashMap< + CompactString, + ( + CompactString, + Option, + Statement, + Vec, + ), + > = HashMap::new(); - let mut row_cache: Vec = vec![]; + let mut open_tx = None; 'outer: while let Some(msg) = front_rx.blocking_recv() { println!("msg: {msg:?}"); match msg { 
PgWireFrontendMessage::Startup(_) => { - back_tx.blocking_send(( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "FATAL".into(), - SqlState::PROTOCOL_VIOLATION.code().into(), - "unexpected startup message".into(), - ) + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".into(), + SqlState::PROTOCOL_VIOLATION.code().into(), + "unexpected startup message".into(), + ) + .into(), + ), + true, + ) .into(), - ), - true, - ))?; + )?; continue; } PgWireFrontendMessage::Parse(parse) => { - let prepped = match conn.prepare(parse.query()) { - Ok(prepped) => prepped, + let mut query = match parse_query(parse.query()) { + Ok(query) => query, Err(e) => { - back_tx.blocking_send(( + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + }; + + if query.len() > 1 { + back_tx.blocking_send( + ( PgWireBackendMessage::ErrorResponse( ErrorInfo::new( "ERROR".to_owned(), - "XX000".to_owned(), - e.to_string(), + sql_state::SqlState::PROTOCOL_VIOLATION + .code() + .into(), + "only 1 command per Parse is allowed".into(), ) .into(), ), true, - ))?; + ) + .into(), + )?; + continue; + } + + let parsed_cmd = if query.is_empty() { + None + } else { + Some(query.cmds.remove(0)) + }; + + println!("parsed cmd: {parsed_cmd:?}"); + + let prepped = match conn.prepare(parse.query()) { + Ok(prepped) => prepped, + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; continue; } }; @@ -198,6 +475,7 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { parse.name().as_deref().unwrap_or("").into(), ( parse.query().clone(), + parsed_cmd, prepped, parse .type_oids() @@ -207,35 +485,40 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { ), 
); - back_tx.blocking_send(( - PgWireBackendMessage::ParseComplete(ParseComplete::new()), - true, - ))?; + back_tx.blocking_send( + ( + PgWireBackendMessage::ParseComplete(ParseComplete::new()), + true, + ) + .into(), + )?; } PgWireFrontendMessage::Describe(desc) => { let name = desc.name().as_deref().unwrap_or(""); match desc.target_type() { // statement b'S' => { - if let Some((_, prepped, _)) = prepared.get(name) { + if let Some((_, cmd, prepped, param_types)) = + prepared.get(name) + { let mut oids = vec![]; let mut fields = vec![]; for col in prepped.columns() { - let col_type = match name_to_type( - col.decl_type().unwrap_or("text"), - ) { - Ok(t) => t, - Err(e) => { - back_tx.blocking_send(( + let col_type = + match name_to_type( + col.decl_type().unwrap_or("text"), + ) { + Ok(t) => t, + Err(e) => { + back_tx.blocking_send(( PgWireBackendMessage::ErrorResponse( e.into(), ), true, - ))?; - continue 'outer; - } - }; - oids.push(col_type.oid()); + ).into())?; + continue 'outer; + } + }; fields.push(FieldInfo::new( col.name().to_string(), None, @@ -244,97 +527,153 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { FieldFormat::Text, )); } - back_tx.blocking_send(( - PgWireBackendMessage::ParameterDescription( - ParameterDescription::new(oids), - ), - false, - ))?; - back_tx.blocking_send(( - PgWireBackendMessage::RowDescription( - RowDescription::new( - fields.iter().map(Into::into).collect(), + + if param_types.len() != prepped.parameter_count() { + if let Some(ParsedCmd(Cmd::Stmt(stmt))) = &cmd { + let params = parameter_types(&schema, stmt); + println!("GOT PARAMS TO OVERRIDE: {params:?}"); + for param in params { + oids.push(match param { + SqliteType::Null => unreachable!(), + SqliteType::Integer => Type::INT8.oid(), + SqliteType::Real => Type::FLOAT8.oid(), + SqliteType::Text => Type::TEXT.oid(), + SqliteType::Blob => Type::BYTEA.oid(), + }) + } + } + } else { + for param in 0..prepped.parameter_count() { + // if let Some(t) = param_types.get(param) 
{ + // oids.push(t.oid()); + // } + oids.push( + param_types + .get(param) + .map(|t| t.oid()) + // this should not happen... + .unwrap_or(Type::TEXT.oid()), + ); + } + } + + back_tx.blocking_send( + ( + PgWireBackendMessage::ParameterDescription( + ParameterDescription::new(oids), ), - ), - true, - ))?; + false, + ) + .into(), + )?; + back_tx.blocking_send( + ( + PgWireBackendMessage::RowDescription( + RowDescription::new( + fields.iter().map(Into::into).collect(), + ), + ), + false, + ) + .into(), + )?; continue; } - back_tx.blocking_send(( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".into(), - "XX000".into(), - "statement not found".into(), - ) + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".into(), + "statement not found".into(), + ) + .into(), + ), + true, + ) .into(), - ), - true, - ))?; + )?; } // portal b'P' => { - if let Some((_, prepped)) = portals.get(name) { + if let Some((_, _, prepped, result_formats)) = + portals.get(name) + { let mut oids = vec![]; let mut fields = vec![]; - for col in prepped.columns() { - let col_type = match name_to_type( - col.decl_type().unwrap_or("text"), - ) { - Ok(t) => t, - Err(e) => { - back_tx.blocking_send(( + for (i, col) in + prepped.columns().into_iter().enumerate() + { + let col_type = + match name_to_type( + col.decl_type().unwrap_or("text"), + ) { + Ok(t) => t, + Err(e) => { + back_tx.blocking_send(( PgWireBackendMessage::ErrorResponse( e.into(), ), true, - ))?; - continue 'outer; - } - }; + ).into())?; + continue 'outer; + } + }; oids.push(col_type.oid()); fields.push(FieldInfo::new( col.name().to_string(), None, None, col_type, - FieldFormat::Text, + result_formats + .get(i) + .copied() + .unwrap_or(FieldFormat::Text), )); } - back_tx.blocking_send(( - PgWireBackendMessage::RowDescription( - RowDescription::new( - fields.iter().map(Into::into).collect(), + back_tx.blocking_send( + ( + PgWireBackendMessage::RowDescription( + 
RowDescription::new( + fields.iter().map(Into::into).collect(), + ), ), - ), - true, - ))?; + true, + ) + .into(), + )?; continue; } - back_tx.blocking_send(( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".into(), - "XX000".into(), - "portal not found".into(), - ) + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".into(), + "portal not found".into(), + ) + .into(), + ), + true, + ) .into(), - ), - true, - ))?; + )?; } _ => { - back_tx.blocking_send(( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "FATAL".into(), - SqlState::PROTOCOL_VIOLATION.code().into(), - "unexpected describe type".into(), - ) + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".into(), + SqlState::PROTOCOL_VIOLATION.code().into(), + "unexpected describe type".into(), + ) + .into(), + ), + true, + ) .into(), - ), - true, - ))?; + )?; continue; } } @@ -348,19 +687,22 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { let stmt_name = bind.statement_name().as_deref().unwrap_or(""); - let (sql, _, param_types) = match prepared.get(stmt_name) { + let (sql, query, _, param_types) = match prepared.get(stmt_name) { None => { - back_tx.blocking_send(( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - "statement not found".into(), - ) + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + "statement not found".into(), + ) + .into(), + ), + true, + ) .into(), - ), - true, - ))?; + )?; continue; } Some(stmt) => stmt, @@ -369,21 +711,67 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { let mut prepped = match conn.prepare(sql) { Ok(prepped) => prepped, Err(e) => { - back_tx.blocking_send(( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - e.to_string(), - ) + 
back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) .into(), - ), - true, - ))?; + )?; continue; } }; + let param_types = if bind.parameters().len() != param_types.len() { + if let Some(ParsedCmd(Cmd::Stmt(stmt))) = &query { + let params = parameter_types(&schema, stmt); + println!("computed params: {params:?}",); + params + .iter() + .map(|param| match param { + SqliteType::Null => unreachable!(), + SqliteType::Integer => Type::INT8, + SqliteType::Real => Type::FLOAT8, + SqliteType::Text => Type::TEXT, + SqliteType::Blob => Type::BYTEA, + }) + .collect() + } else { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + "could not determine parameter type".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + } else { + param_types.clone() + }; + + println!("CMD: {query:?}"); + if bind.parameters().len() != param_types.len() { + if let Some(ParsedCmd(Cmd::Stmt(stmt))) = &query { + let params = parameter_types(&schema, stmt); + println!("computed params: {params:?}",); + } + } + for (i, param) in bind.parameters().iter().enumerate() { let idx = i + 1; let b = match param { @@ -391,17 +779,20 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { if let Err(e) = prepped .raw_bind_parameter(idx, rusqlite::types::Null) { - back_tx.blocking_send(( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - e.to_string(), - ) + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) .into(), - ), - true, - ))?; + )?; continue 'outer; } continue; @@ -411,17 +802,20 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { match param_types.get(i) { None => { - 
back_tx.blocking_send(( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - "missing parameter type".into(), - ) + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + "missing parameter type".into(), + ) + .into(), + ), + true, + ) .into(), - ), - true, - ))?; + )?; continue 'outer; } Some(param_type) => match param_type { @@ -467,65 +861,119 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { } t => { warn!("unsupported type: {t:?}"); - back_tx.blocking_send(( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!( + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + format!( "unsupported type {t} at index {i}" ), - ) + ) + .into(), + ), + true, + ) .into(), - ), - true, - ))?; + )?; continue 'outer; } }, } } - portals.insert(portal_name, (stmt_name.into(), prepped)); + portals.insert( + portal_name, + ( + stmt_name.into(), + query.clone(), + prepped, + bind.result_column_format_codes() + .iter() + .copied() + .map(FieldFormat::from) + .collect(), + ), + ); + + back_tx.blocking_send( + ( + PgWireBackendMessage::BindComplete(BindComplete::new()), + false, + ) + .into(), + )?; } PgWireFrontendMessage::Sync(_) => { - back_tx.blocking_send(( - PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new( - READY_STATUS_IDLE, - )), - true, - ))?; + let ready_status = if open_tx.is_some() { + READY_STATUS_TRANSACTION_BLOCK + } else { + READY_STATUS_IDLE + }; + back_tx.blocking_send( + ( + PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new( + ready_status, + )), + true, + ) + .into(), + )?; } - PgWireFrontendMessage::Execute(_) => todo!(), - PgWireFrontendMessage::Query(query) => { - let mut prepped = match conn.prepare(query.query()) { - Ok(prepped) => prepped, - Err(e) => { - back_tx.blocking_send(( - 
PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - e.to_string(), + PgWireFrontendMessage::Execute(execute) => { + let name = execute.name().as_deref().unwrap_or(""); + let (parsed_cmd, prepped, result_formats) = + match portals.get_mut(name) { + Some((_, Some(parsed_cmd), prepped, result_formats)) => { + (parsed_cmd, prepped, result_formats) + } + Some((_, None, _, _)) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::EmptyQueryResponse( + EmptyQueryResponse::new(), + ), + false, ) - .into(), - ), - true, - ))?; - continue; - } - }; + .into(), + )?; + continue; + } + None => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".into(), + "portal not found".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + }; let mut fields = vec![]; - for col in prepped.columns() { + for (i, col) in prepped.columns().into_iter().enumerate() { let col_type = match name_to_type(col.decl_type().unwrap_or("text")) { Ok(t) => t, Err(e) => { - back_tx.blocking_send(( - PgWireBackendMessage::ErrorResponse(e.into()), - true, - ))?; + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + e.into(), + ), + true, + ) + .into(), + )?; continue 'outer; } }; @@ -534,24 +982,61 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { None, None, col_type, - FieldFormat::Text, + result_formats.get(i).copied().unwrap_or(FieldFormat::Text), )); } - back_tx.blocking_send(( - PgWireBackendMessage::RowDescription(RowDescription::new( - fields.iter().map(Into::into).collect(), - )), - true, - ))?; - let schema = Arc::new(fields); let mut rows = prepped.raw_query(); let ncols = schema.len(); + let max_rows = *execute.max_rows(); + let max_rows = if max_rows == 0 { + usize::MAX + } else { + max_rows as usize + }; let mut count = 0; - while let Ok(Some(row)) = rows.next() { + + loop { + if count >= max_rows { + std::mem::forget(rows); + 
back_tx.blocking_send( + ( + PgWireBackendMessage::PortalSuspended( + PortalSuspended::new(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + let row = match rows.next() { + Ok(Some(row)) => row, + Ok(None) => { + println!("done w/ rows"); + break; + } + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + }; count += 1; let mut encoder = DataRowEncoder::new(schema.clone()); for idx in 0..ncols { @@ -578,36 +1063,403 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { } } } + match encoder.finish() { + Ok(data_row) => { + back_tx.blocking_send( + (PgWireBackendMessage::DataRow(data_row), false) + .into(), + )?; + } + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + } } - // TODO: figure out what kind of execution it is: SELECT, INSERT, etc. - back_tx.blocking_send(( - PgWireBackendMessage::CommandComplete( - Tag::new_for_query(count).into(), - ), - true, - ))?; + let tag = if parsed_cmd.returns_num_rows() { + parsed_cmd.tag(Some(count)) + } else if parsed_cmd.returns_rows_affected() { + parsed_cmd.tag(Some(conn.changes() as usize)) + } else { + parsed_cmd.tag(None) + }; + + // done! 
+ back_tx.blocking_send( + (PgWireBackendMessage::CommandComplete(tag.into()), true) + .into(), + )?; + } + PgWireFrontendMessage::Query(query) => { + let trimmed = query.query().trim_matches(';'); + + let parsed_query = match parse_query(trimmed) { + Ok(q) => q, + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + }; + + if parsed_query.is_empty() { + back_tx.blocking_send( + ( + PgWireBackendMessage::EmptyQueryResponse( + EmptyQueryResponse::new(), + ), + false, + ) + .into(), + )?; + + let ready_status = if open_tx.is_some() { + ReadyForQuery::new(READY_STATUS_TRANSACTION_BLOCK) + } else { + ReadyForQuery::new(READY_STATUS_IDLE) + }; + + back_tx.blocking_send( + (PgWireBackendMessage::ReadyForQuery(ready_status), true) + .into(), + )?; + continue; + } + + let mut cmd_iter = parsed_query.cmds.into_iter().peekable(); + + loop { + let cmd = match cmd_iter.next() { + None => break, + Some(cmd) => cmd, + }; + + // need to start an implicit transaction + if open_tx.is_none() && !cmd.is_begin() { + if let Err(e) = conn.execute_batch("BEGIN") { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + println!("started IMPLICIT tx"); + open_tx = Some(OpenTx::Implicit); + } + + // close the current implement tx first + if matches!(open_tx, Some(OpenTx::Implicit)) && cmd.is_begin() { + println!("committing IMPLICIT tx"); + open_tx = None; + + if let Err(e) = handle_commit(&agent, &conn, "COMMIT") { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + + continue 'outer; + } + println!("committed IMPLICIT 
tx"); + } + + let count = if cmd.is_commit() { + open_tx = None; + + if let Err(e) = + handle_commit(&agent, &conn, &cmd.0.to_string()) + { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + + continue 'outer; + } + None + } else { + let mut prepped = match conn.prepare(&cmd.0.to_string()) { + Ok(prepped) => prepped, + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + }; + + let mut fields = vec![]; + for col in prepped.columns() { + let col_type = match name_to_type( + col.decl_type().unwrap_or("text"), + ) { + Ok(t) => t, + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + e.into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + }; + fields.push(FieldInfo::new( + col.name().to_string(), + None, + None, + col_type, + FieldFormat::Text, + )); + } + + back_tx.blocking_send( + ( + PgWireBackendMessage::RowDescription( + RowDescription::new( + fields.iter().map(Into::into).collect(), + ), + ), + true, + ) + .into(), + )?; + + let schema = Arc::new(fields); + + let mut rows = prepped.raw_query(); + let ncols = schema.len(); + + let mut count = 0; + while let Ok(Some(row)) = rows.next() { + count += 1; + let mut encoder = DataRowEncoder::new(schema.clone()); + for idx in 0..ncols { + let data = row.get_ref_unwrap::(idx); + match data { + ValueRef::Null => { + encoder.encode_field(&None::).unwrap() + } + ValueRef::Integer(i) => { + encoder.encode_field(&i).unwrap(); + } + ValueRef::Real(f) => { + encoder.encode_field(&f).unwrap(); + } + ValueRef::Text(t) => { + encoder + .encode_field( + &String::from_utf8_lossy(t) + .as_ref(), + ) + .unwrap(); + } + ValueRef::Blob(b) => { + 
encoder.encode_field(&b).unwrap(); + } + } + } + match encoder.finish() { + Ok(data_row) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::DataRow(data_row), + false, + ) + .into(), + )?; + } + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + } + } + Some(count) + }; + + let tag = if cmd.returns_num_rows() { + cmd.tag(count) + } else if cmd.returns_rows_affected() { + cmd.tag(Some(conn.changes() as usize)) + } else { + cmd.tag(None) + }; + + back_tx.blocking_send( + (PgWireBackendMessage::CommandComplete(tag.into()), true) + .into(), + )?; + + if cmd.is_begin() { + println!("setting EXPLICIT tx"); + // explicit tx + open_tx = Some(OpenTx::Explicit) + } else if cmd.is_rollback() || cmd.is_commit() { + println!("clearing current open tx"); + // if this was a rollback, remove the current open tx + open_tx = None; + } + } + + // automatically commit an implicit tx + if matches!(open_tx, Some(OpenTx::Implicit)) { + println!("committing IMPLICIT tx"); + open_tx = None; + + if let Err(e) = handle_commit(&agent, &conn, "COMMIT") { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + println!("committed IMPLICIT tx"); + } + + let ready_status = if open_tx.is_some() { + ReadyForQuery::new(READY_STATUS_TRANSACTION_BLOCK) + } else { + ReadyForQuery::new(READY_STATUS_IDLE) + }; + + back_tx.blocking_send( + (PgWireBackendMessage::ReadyForQuery(ready_status), true) + .into(), + )?; } PgWireFrontendMessage::Terminate(_) => { break; } - PgWireFrontendMessage::PasswordMessageFamily(_) => todo!(), + PgWireFrontendMessage::PasswordMessageFamily(_) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + 
"ERROR".into(), + "XX000".to_owned(), + "PasswordMessage is not implemented".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } PgWireFrontendMessage::Close(close) => { let name = close.name().as_deref().unwrap_or(""); match close.target_type() { // statement b'S' => { - if let Some((_, prepped, _)) = prepared.remove(name) { - portals.retain(|_, (stmt_name, _)| { + if prepared.remove(name).is_some() { + portals.retain(|_, (stmt_name, _, _, _)| { stmt_name.as_str() != name }); - back_tx.blocking_send(( - PgWireBackendMessage::CloseComplete( - CloseComplete::new(), - ), - true, - ))?; + back_tx.blocking_send( + ( + PgWireBackendMessage::CloseComplete( + CloseComplete::new(), + ), + true, + ) + .into(), + )?; continue; } // not finding a statement is not an error @@ -615,33 +1467,89 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { // portal b'P' => { portals.remove(name); - back_tx.blocking_send(( - PgWireBackendMessage::CloseComplete( - CloseComplete::new(), - ), - true, - ))?; + back_tx.blocking_send( + ( + PgWireBackendMessage::CloseComplete( + CloseComplete::new(), + ), + true, + ) + .into(), + )?; } _ => { - back_tx.blocking_send(( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "FATAL".into(), - SqlState::PROTOCOL_VIOLATION.code().into(), - "unexpected Close target_type".into(), - ) + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".into(), + SqlState::PROTOCOL_VIOLATION.code().into(), + "unexpected Close target_type".into(), + ) + .into(), + ), + true, + ) .into(), - ), - true, - ))?; + )?; continue; } } } - PgWireFrontendMessage::Flush(_) => todo!(), - PgWireFrontendMessage::CopyData(_) => todo!(), - PgWireFrontendMessage::CopyFail(_) => todo!(), - PgWireFrontendMessage::CopyDone(_) => todo!(), + PgWireFrontendMessage::Flush(_) => { + back_tx.blocking_send(BackendResponse::Flush)?; + } + PgWireFrontendMessage::CopyData(_) => { + back_tx.blocking_send( + ( + 
PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".to_owned(), + "CopyData is not implemented".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + PgWireFrontendMessage::CopyFail(_) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".to_owned(), + "CopyFail is not implemented".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + PgWireFrontendMessage::CopyDone(_) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".to_owned(), + "CopyDone is not implemented".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } } } @@ -662,10 +1570,11 @@ async fn start(agent: Agent, pg: PgConfig) -> Result { fn name_to_type(name: &str) -> Result { match name.to_uppercase().as_ref() { - "INT" => Ok(Type::INT8), + "ANY" => Ok(Type::ANY), + "INT" | "INTEGER" => Ok(Type::INT8), "VARCHAR" => Ok(Type::VARCHAR), "TEXT" => Ok(Type::TEXT), - "BINARY" => Ok(Type::BYTEA), + "BINARY" | "BLOB" => Ok(Type::BYTEA), "FLOAT" => Ok(Type::FLOAT8), _ => Err(ErrorInfo::new( "ERROR".to_owned(), @@ -675,9 +1584,463 @@ fn name_to_type(name: &str) -> Result { } } +fn handle_commit(agent: &Agent, conn: &Connection, commit_stmt: &str) -> rusqlite::Result<()> { + let actor_id = agent.actor_id(); + + let ts = Timestamp::from(agent.clock().new_timestamp()); + + let db_version: i64 = conn + .prepare_cached("SELECT crsql_next_db_version()")? + .query_row((), |row| row.get(0))?; + + let has_changes: bool = conn + .prepare_cached( + "SELECT EXISTS(SELECT 1 FROM crsql_changes WHERE site_id IS NULL AND db_version = ?);", + )? 
+ .query_row([db_version], |row| row.get(0))?; + + if !has_changes { + conn.execute_batch(commit_stmt)?; + return Ok(()); + } + + let booked = { + agent + .bookie() + .blocking_write("handle_write_tx(for_actor)") + .for_actor(actor_id) + }; + + let last_seq: i64 = conn + .prepare_cached( + "SELECT MAX(seq) FROM crsql_changes WHERE site_id IS NULL AND db_version = ?", + )? + .query_row([db_version], |row| row.get(0))?; + + let mut book_writer = booked.blocking_write("handle_write_tx(book_writer)"); + + let last_version = book_writer.last().unwrap_or_default(); + trace!("last_version: {last_version}"); + let version = last_version + 1; + trace!("version: {version}"); + + conn.prepare_cached( + r#" + INSERT INTO __corro_bookkeeping (actor_id, start_version, db_version, last_seq, ts) + VALUES (:actor_id, :start_version, :db_version, :last_seq, :ts); + "#, + )? + .execute(named_params! { + ":actor_id": actor_id, + ":start_version": version, + ":db_version": db_version, + ":last_seq": last_seq, + ":ts": ts + })?; + + debug!(%actor_id, %version, %db_version, "inserted local bookkeeping row!"); + + conn.execute_batch(commit_stmt)?; + + trace!("committed tx, db_version: {db_version}, last_seq: {last_seq:?}"); + + book_writer.insert( + version, + KnownDbVersion::Current { + db_version, + last_seq, + ts, + }, + ); + + Ok(()) +} + +fn compute_schema(conn: &Connection) -> Result { + let mut dump = String::new(); + + let tables: HashMap = conn + .prepare(r#"SELECT name, sql FROM sqlite_schema WHERE type = "table" AND name IS NOT NULL AND sql IS NOT NULL ORDER BY tbl_name"#)? + .query_map((), |row| { + Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?)) + })? + .collect::>()?; + + for sql in tables.values() { + dump.push_str(sql.as_str()); + dump.push(';'); + } + + let indexes: HashMap = conn + .prepare(r#"SELECT name, sql FROM sqlite_schema WHERE type = "index" AND name IS NOT NULL AND sql IS NOT NULL ORDER BY tbl_name"#)? 
+ .query_map((), |row| { + Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?)) + })? + .collect::>()?; + + for sql in indexes.values() { + dump.push_str(sql.as_str()); + dump.push(';'); + } + + parse_sql(dump.as_str()) +} + +fn is_param(expr: &Expr) -> bool { + matches!(expr, Expr::Variable(_)) +} + +enum SqliteNameRef<'a> { + Id(&'a Id), + Name(&'a Name), + Qualified(&'a Name, &'a Name), + DoublyQualified(&'a Name, &'a Name, &'a Name), +} + +impl<'a> SqliteNameRef<'a> { + fn to_owned(&self) -> SqliteName { + match self { + SqliteNameRef::Id(id) => SqliteName::Id((*id).clone()), + SqliteNameRef::Name(name) => SqliteName::Name((*name).clone()), + SqliteNameRef::Qualified(n0, n1) => SqliteName::Qualified((*n0).clone(), (*n1).clone()), + SqliteNameRef::DoublyQualified(n0, n1, n2) => { + SqliteName::DoublyQualified((*n0).clone(), (*n1).clone(), (*n2).clone()) + } + } + } +} + +#[derive(Clone, Debug)] +enum SqliteName { + Id(Id), + Name(Name), + Qualified(Name, Name), + DoublyQualified(Name, Name, Name), +} + +enum ParamType { + Type(SqliteType), + FromColumn(SqliteName), +} + +fn expr_to_name(expr: &Expr) -> Option { + match expr { + Expr::Id(id) => Some(SqliteNameRef::Id(id)), + Expr::Name(name) => Some(SqliteNameRef::Name(name)), + Expr::Qualified(n0, n1) => Some(SqliteNameRef::Qualified(n0, n1)), + Expr::DoublyQualified(n0, n1, n2) => Some(SqliteNameRef::DoublyQualified(n0, n1, n2)), + _ => None, + } +} + +fn literal_type(expr: &Expr) -> Option { + match expr { + Expr::Literal(lit) => match lit { + Literal::Numeric(num) => { + if num.parse::().is_ok() { + Some(SqliteType::Integer) + } else if num.parse::().is_ok() { + Some(SqliteType::Real) + } else { + // this should be unreachable... + None + } + } + Literal::String(_) => Some(SqliteType::Text), + Literal::Blob(_) => Some(SqliteType::Blob), + Literal::Keyword(keyword) => { + // TODO: figure out what this is... 
+ warn!("got a keyword: {keyword}"); + None + } + Literal::Null => Some(SqliteType::Null), + Literal::CurrentDate | Literal::CurrentTime | Literal::CurrentTimestamp => { + // TODO: make this configurable at connection time or something + Some(SqliteType::Text) + } + }, + _ => None, + } +} + +fn handle_lhs_rhs(lhs: &Expr, rhs: &Expr) -> Option { + match ( + (expr_to_name(lhs), is_param(lhs)), + (expr_to_name(rhs), is_param(rhs)), + ) { + ((Some(name), _), (_, true)) | ((_, true), (Some(name), _)) => Some(name.to_owned()), + _ => None, + } +} + +// fn handle_select(select: &Select, params: &mut Vec) { +// let mut aliases = HashMap::new(); +// match &select.body.select { +// OneSelect::Select { +// columns, +// from, +// where_clause, +// .. +// } => { +// // process FROM for aliases only + +// if let Some(from) = from { +// if let Some(select) = from.select.as_deref() { +// match select { +// SelectTable::Table(name, alias, _) => {} +// SelectTable::TableCall(name, _, alias) => {} +// SelectTable::Select(_, alias) => {} +// SelectTable::Sub(_, _) => {} +// } +// } +// } + +// for col in columns.iter() { +// match col { +// ResultColumn::Expr(expr, _) => { +// if is_param(expr) { +// params.push(ParamType::Type(SqliteType::Text)); +// } else { +// extract_param(expr, &aliases, params); +// } +// } +// _ => { +// // nothing to do here I don't think +// } +// } +// } + +// if let Some(from) = from { +// if let Some(select) = from.select.as_deref() { +// match select { +// SelectTable::Table(_, _, _) => {} +// SelectTable::TableCall(_, _, _) => {} +// SelectTable::Select(_, _) => {} +// SelectTable::Sub(_, _) => {} +// } +// } +// } +// } +// OneSelect::Values(_) => { +// // TODO: handle values +// } +// } +// } + +fn extract_param(expr: &Expr, tables: &HashMap, params: &mut Vec) { + match expr { + Expr::Between { + lhs, start, end, .. 
+ } => {} + Expr::Binary(lhs, _, rhs) => { + if let Some(name) = handle_lhs_rhs(lhs, rhs) { + println!("HANDLED LHS RHS: {name:?}"); + match name { + // not aliased! + SqliteName::Id(id) => { + // find the first one to match + for (_, table) in tables.iter() { + if let Some(col) = table.columns.get(&id.0) { + params.push(col.sql_type); + break; + } + } + } + SqliteName::Name(_) => {} + SqliteName::Qualified(tbl_name, col_name) + | SqliteName::DoublyQualified(_, tbl_name, col_name) => { + if let Some(table) = tables.get(&tbl_name.0) { + if let Some(col) = table.columns.get(&col_name.0) { + params.push(col.sql_type); + } + } + } + } + } + } + Expr::Case { + base, + when_then_pairs, + else_expr, + } => {} + Expr::Cast { expr, type_name } => {} + Expr::Collate(_, _) => {} + Expr::DoublyQualified(_, _, _) => {} + Expr::Exists(_) => {} + Expr::FunctionCall { + name, + distinctness, + args, + filter_over, + } => {} + Expr::FunctionCallStar { name, filter_over } => {} + Expr::Id(_) => {} + Expr::InList { lhs, not, rhs } => {} + Expr::InSelect { lhs, not, rhs } => {} + Expr::InTable { + lhs, + not, + rhs, + args, + } => {} + Expr::IsNull(_) => {} + Expr::Like { + lhs, + not, + op, + rhs, + escape, + } => {} + Expr::Literal(_) => {} + Expr::Name(_) => {} + Expr::NotNull(_) => {} + Expr::Parenthesized(_) => {} + Expr::Qualified(_, _) => {} + Expr::Raise(_, _) => {} + Expr::Subquery(_) => {} + Expr::Unary(_, _) => {} + Expr::Variable(_) => {} + } +} + +fn parameter_types(schema: &Schema, stmt: &Stmt) -> Vec { + let mut params = vec![]; + + match stmt { + Stmt::Select(select) => match &select.body.select { + OneSelect::Select { + columns, + from, + where_clause, + .. 
+ } => { + let mut tables: HashMap = HashMap::new(); + if let Some(from) = from { + if let Some(select) = from.select.as_deref() { + match select { + SelectTable::Table(qname, maybe_alias, _) => { + if let Some(alias) = maybe_alias { + let alias = match alias { + sqlite3_parser::ast::As::As(name) => name.0.clone(), + sqlite3_parser::ast::As::Elided(name) => name.0.clone(), + }; + if let Some(table) = schema.tables.get(&qname.name.0) { + tables.insert(alias, table); + } + } else { + // not aliased + if let Some(table) = schema.tables.get(&qname.name.0) { + tables.insert(qname.name.0.clone(), table); + } + } + } + SelectTable::TableCall(_, _, _) => {} + SelectTable::Select(_, _) => {} + SelectTable::Sub(_, _) => {} + } + } + if let Some(joins) = &from.joins { + for join in joins.iter() { + match &join.table { + SelectTable::Table(qname, maybe_alias, _) => { + if let Some(alias) = maybe_alias { + let alias = match alias { + sqlite3_parser::ast::As::As(name) => name.0.clone(), + sqlite3_parser::ast::As::Elided(name) => name.0.clone(), + }; + if let Some(table) = schema.tables.get(&qname.name.0) { + tables.insert(alias, table); + } + } else { + // not aliased + if let Some(table) = schema.tables.get(&qname.name.0) { + tables.insert(qname.name.0.clone(), table); + } + } + } + SelectTable::TableCall(_, _, _) => {} + SelectTable::Select(_, _) => {} + SelectTable::Sub(_, _) => {} + } + } + } + } + if let Some(where_clause) = where_clause { + println!("WHERE CLAUSE: {where_clause:?}"); + extract_param(where_clause, &tables, &mut params); + } + } + OneSelect::Values(_) => { + // TODO: handle this somehow... + } + }, + Stmt::Delete { + tbl_name, + where_clause: Some(where_clause), + .. + } => {} + Stmt::Insert { + tbl_name, + columns, + body, + .. 
+ } => { + println!("GOT AN INSERT TO {tbl_name:?} on columns: {columns:?} w/ body: {body:?}"); + if let Some(table) = schema.tables.get(&tbl_name.name.0) { + match body { + InsertBody::Select(select, _) => match &select.body.select { + OneSelect::Select { + distinctness, + columns, + from, + where_clause, + group_by, + window_clause, + } => { + // handle this at some point... like any other select! + } + OneSelect::Values(values_values) => { + for values in values_values.iter() { + for (i, expr) in values.iter().enumerate() { + if is_param(expr) { + if let Some((name, col)) = table.columns.get_index(i) { + params.push(col.sql_type); + } + } + } + } + } + }, + InsertBody::DefaultValues => { + // nothing to do! + } + } + } + } + Stmt::Update { + with, + or_conflict, + tbl_name, + indexed, + sets, + from, + where_clause, + returning, + order_by, + limit, + } => {} + _ => { + // do nothing, there can't be bound params here! + } + } + + params +} + #[cfg(test)] mod tests { use corro_tests::launch_test_agent; + use postgres_types::ToSql; use tokio_postgres::NoTls; use tripwire::Tripwire; @@ -704,15 +2067,85 @@ mod tests { server.local_addr.port() ); - let (client, client_conn) = tokio_postgres::connect(&conn_str, NoTls).await?; + let (mut client, client_conn) = tokio_postgres::connect(&conn_str, NoTls).await?; + // let (mut client, client_conn) = + // tokio_postgres::connect("host=localhost port=5432 user=jerome", NoTls).await?; println!("client is ready!"); tokio::spawn(client_conn); + println!("before prepare"); let stmt = client.prepare("SELECT 1").await?; - println!("after prepare"); + println!( + "after prepare: params: {:?}, columns: {:?}", + stmt.params(), + stmt.columns() + ); + println!("before query"); let rows = client.query(&stmt, &[]).await?; - println!("rows: {rows:?}"); + + println!("rows count: {}", rows.len()); + for row in rows { + println!("ROW!!! 
{row:?}"); + } + + println!("before execute"); + let affected = client + .execute("INSERT INTO tests VALUES (1,2)", &[]) + .await?; + println!("after execute, affected: {affected}"); + + let row = client.query_one("SELECT * FROM crsql_changes", &[]).await?; + println!("CHANGE ROW: {row:?}"); + + client + .batch_execute("SELECT 1; SELECT 2; SELECT 3;") + .await?; + println!("after batch exec"); + + client.batch_execute("SELECT 1; BEGIN; SELECT 3;").await?; + println!("after batch exec 2"); + + client.batch_execute("SELECT 3; COMMIT; SELECT 3;").await?; + println!("after batch exec 3"); + + let tx = client.transaction().await?; + println!("after begin I assume"); + let res = tx + .execute( + "INSERT INTO tests VALUES ($1, $2)", + &[&2i64, &"hello world"], + ) + .await?; + println!("res (rows affected): {res}"); + let res = tx + .execute( + "INSERT INTO tests2 VALUES ($1, $2)", + &[&2i64, &"hello world 2"], + ) + .await?; + println!("res (rows affected): {res}"); + tx.commit().await?; + println!("after commit"); + + let row = client + .query_one("SELECT * FROM tests t WHERE t.id = ?", &[&2i64]) + .await?; + println!("ROW: {row:?}"); + + let row = client + .query_one("SELECT * FROM tests t WHERE t.id = ?", &[&2i64]) + .await?; + println!("ROW: {row:?}"); + + let row = client + .query_one("SELECT t.id, t.text, t2.text as t2text FROM tests t LEFT JOIN tests2 t2 WHERE t.id = ?", &[&2i64]) + .await?; + println!("ROW: {row:?}"); + + println!("t.id: {:?}", row.try_get::<_, i64>(0)); + println!("t.text: {:?}", row.try_get::<_, String>(1)); + println!("t2text: {:?}", row.try_get::<_, String>(2)); Ok(()) } diff --git a/crates/corro-types/src/agent.rs b/crates/corro-types/src/agent.rs index dc4192cf..fd78ea75 100644 --- a/crates/corro-types/src/agent.rs +++ b/crates/corro-types/src/agent.rs @@ -40,7 +40,7 @@ use crate::{ broadcast::{BroadcastInput, ChangeSource, ChangeV1, FocaInput, Timestamp}, config::Config, pubsub::MatcherHandle, - schema::NormalizedSchema, + schema::Schema, 
sqlite::{rusqlite_to_crsqlite, setup_conn, AttachMap, CrConn, SqlitePool, SqlitePoolError}, }; @@ -67,7 +67,7 @@ pub struct AgentConfig { pub tx_changes: Sender<(ChangeV1, ChangeSource)>, pub tx_foca: Sender, - pub schema: RwLock, + pub schema: RwLock, pub tripwire: Tripwire, } @@ -86,7 +86,7 @@ pub struct AgentInner { tx_empty: Sender<(ActorId, RangeInclusive)>, tx_changes: Sender<(ChangeV1, ChangeSource)>, tx_foca: Sender, - schema: RwLock, + schema: RwLock, limits: Limits, } @@ -171,7 +171,7 @@ impl Agent { &self.0.members } - pub fn schema(&self) -> &RwLock { + pub fn schema(&self) -> &RwLock { &self.0.schema } diff --git a/crates/corro-types/src/pubsub.rs b/crates/corro-types/src/pubsub.rs index 04b20fe3..176ddec6 100644 --- a/crates/corro-types/src/pubsub.rs +++ b/crates/corro-types/src/pubsub.rs @@ -25,7 +25,7 @@ use uuid::Uuid; use crate::{ api::QueryEvent, - schema::{NormalizedSchema, NormalizedTable}, + schema::{Schema, Table}, sqlite::Migration, }; @@ -356,7 +356,7 @@ const CHANGE_TYPE_COL: &str = "type"; impl Matcher { fn new( id: Uuid, - schema: &NormalizedSchema, + schema: &Schema, conn: &Connection, evt_tx: mpsc::Sender, sql: &str, @@ -530,7 +530,7 @@ impl Matcher { pub fn restore( id: Uuid, - schema: &NormalizedSchema, + schema: &Schema, conn: Connection, evt_tx: mpsc::Sender, sql: &str, @@ -544,7 +544,7 @@ impl Matcher { pub fn create( id: Uuid, - schema: &NormalizedSchema, + schema: &Schema, mut conn: Connection, evt_tx: mpsc::Sender, sql: &str, @@ -1040,10 +1040,7 @@ pub struct ParsedSelect { children: Vec, } -fn extract_select_columns( - select: &Select, - schema: &NormalizedSchema, -) -> Result { +fn extract_select_columns(select: &Select, schema: &Schema) -> Result { let mut parsed = ParsedSelect::default(); if let OneSelect::Select { @@ -1140,7 +1137,7 @@ fn extract_select_columns( fn extract_expr_columns( expr: &Expr, - schema: &NormalizedSchema, + schema: &Schema, parsed: &mut ParsedSelect, ) -> Result<(), MatcherError> { match expr { @@ 
-1318,7 +1315,7 @@ fn extract_expr_columns( fn extract_columns( columns: &[ResultColumn], from: Option<&Name>, - schema: &NormalizedSchema, + schema: &Schema, parsed: &mut ParsedSelect, ) -> Result<(), MatcherError> { let mut i = 0; @@ -1382,7 +1379,7 @@ fn extract_columns( fn table_to_expr( aliases: &HashMap, - tbl: &NormalizedTable, + tbl: &Table, table: &str, id: Uuid, ) -> Result { @@ -1521,7 +1518,7 @@ mod tests { { let tx = conn.transaction()?; - apply_schema(&tx, &NormalizedSchema::default(), &mut schema)?; + apply_schema(&tx, &Schema::default(), &mut schema)?; tx.commit()?; } @@ -1653,7 +1650,7 @@ mod tests { { let tx = conn.transaction().unwrap(); - apply_schema(&tx, &NormalizedSchema::default(), &mut schema).unwrap(); + apply_schema(&tx, &Schema::default(), &mut schema).unwrap(); tx.commit().unwrap(); } @@ -1695,7 +1692,7 @@ mod tests { { let tx = conn2.transaction().unwrap(); - apply_schema(&tx, &NormalizedSchema::default(), &mut schema).unwrap(); + apply_schema(&tx, &Schema::default(), &mut schema).unwrap(); tx.commit().unwrap(); } diff --git a/crates/corro-types/src/schema.rs b/crates/corro-types/src/schema.rs index b36f3d6d..e241c2fa 100644 --- a/crates/corro-types/src/schema.rs +++ b/crates/corro-types/src/schema.rs @@ -15,7 +15,7 @@ use sqlite3_parser::ast::{ use tracing::{debug, info}; #[derive(Debug, Clone, Eq, PartialEq)] -pub struct NormalizedColumn { +pub struct Column { pub name: String, pub sql_type: SqliteType, pub nullable: bool, @@ -25,7 +25,7 @@ pub struct NormalizedColumn { pub raw: ColumnDefinition, } -impl std::hash::Hash for NormalizedColumn { +impl std::hash::Hash for Column { fn hash(&self, state: &mut H) { self.name.hash(state); self.sql_type.hash(state); @@ -36,7 +36,7 @@ impl std::hash::Hash for NormalizedColumn { } } -impl fmt::Display for NormalizedColumn { +impl fmt::Display for Column { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.raw.to_fmt(f) } @@ -44,7 +44,7 @@ impl fmt::Display for NormalizedColumn { 
/// SQLite data types. /// See [Fundamental Datatypes](https://sqlite.org/c3ref/c_blob.html). -#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum SqliteType { /// NULL @@ -60,15 +60,15 @@ pub enum SqliteType { } #[derive(Debug, Clone)] -pub struct NormalizedTable { +pub struct Table { pub name: String, pub pk: IndexSet, - pub columns: IndexMap, - pub indexes: IndexMap, + pub columns: IndexMap, + pub indexes: IndexMap, pub raw: CreateTableBody, } -impl fmt::Display for NormalizedTable { +impl fmt::Display for Table { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { Cmd::Stmt(Stmt::CreateTable { temporary: false, @@ -81,16 +81,78 @@ impl fmt::Display for NormalizedTable { } #[derive(Debug, Clone, Eq, PartialEq)] -pub struct NormalizedIndex { +pub struct Index { pub name: String, pub tbl_name: String, pub columns: Vec, pub where_clause: Option, + pub unique: bool, } #[derive(Debug, Clone, Default)] -pub struct NormalizedSchema { - pub tables: IndexMap, +pub struct Schema { + pub tables: IndexMap, +} + +impl Schema { + pub fn constrain(&mut self) -> Result<(), ConstrainedSchemaError> { + self.tables.retain(|name, table| { + !(name.contains("crsql") && name.contains("sqlite") && name.starts_with("__corro")) + }); + + for (tbl_name, table) in self.tables.iter() { + // this should always be the case... + if let CreateTableBody::ColumnsAndConstraints { + columns, + constraints, + options, + } = &table.raw + { + if let Some(constraints) = constraints { + for named in constraints.iter() { + if let TableConstraint::PrimaryKey { columns, .. } = &named.constraint { + for column in columns.iter() { + if !matches!(column.expr, Expr::Id(_)) { + return Err(ConstrainedSchemaError::PrimaryKeyExpr); + } + } + } + } + } + } else { + // error here! 
+ } + + for (name, column) in table.columns.iter() { + if !column.primary_key && !column.nullable && column.default_value.is_none() { + return Err(ConstrainedSchemaError::NotNullableColumnNeedsDefault { + tbl_name: tbl_name.clone(), + name: name.clone(), + }); + } + + if column + .raw + .constraints + .iter() + .any(|named| matches!(named.constraint, ColumnConstraint::ForeignKey { .. })) + { + return Err(ConstrainedSchemaError::ForeignKey { + tbl_name: tbl_name.clone(), + name: name.clone(), + }); + } + } + + for (name, index) in table.indexes.iter() { + if index.unique { + return Err(ConstrainedSchemaError::UniqueIndex(name.clone())); + } + } + } + + Ok(()) + } } #[derive(Debug, thiserror::Error)] @@ -103,51 +165,30 @@ pub enum SchemaError { Parse(#[from] sqlite3_parser::lexer::sql::Error), #[error("nothing to parse")] NothingParsed, - #[error("unsupported statement: {0}")] + #[error("unsupported command: {0}")] UnsupportedCmd(Cmd), - #[error("unique indexes are not supported: {0}")] - UniqueIndex(Cmd), + #[error("missing table for index (table: '{tbl_name}', index: '{name}')")] + IndexWithoutTable { tbl_name: String, name: String }, #[error("temporary tables are not supported: {0}")] TemporaryTable(Cmd), +} + +#[derive(Debug, thiserror::Error)] +pub enum ConstrainedSchemaError { + #[error("unique indexes are not supported: {0}")] + UniqueIndex(String), #[error("table as select arenot supported: {0}")] TableAsSelect(Cmd), #[error("not nullable column '{name}' on table '{tbl_name}' needs a default value for forward schema compatibility")] NotNullableColumnNeedsDefault { tbl_name: String, name: String }, #[error("foreign keys are not supported (table: '{tbl_name}', column: '{name}')")] ForeignKey { tbl_name: String, name: String }, - #[error("missing table for index (table: '{tbl_name}', index: '{name}')")] - IndexWithoutTable { tbl_name: String, name: String }, #[error("expr used as primary")] PrimaryKeyExpr, - #[error("won't drop table without the destructive flag 
set (table: '{0}')")] - DropTableWithoutDestructiveFlag(String), - #[error("won't drop table without the destructive flag set (table: '{0}', column: '{1}')")] - DropColumnWithoutDestructiveFlag(String, String), - #[error("can't add a primary key (table: '{0}', column: '{1}')")] - AddPrimaryKey(String, String), - #[error("can't modify primary keys (table: '{0}')")] - ModifyPrimaryKeys(String), - - #[error("tried importing an existing schema for table '{0}' due to a failed CREATE TABLE but didn't find anything (this should never happen)")] - ImportedSchemaNotFound(String), - - #[error("existing schema for table '{tbl_name}' primary keys mismatched, expected: {expected:?}, got: {got:?}")] - ImportedSchemaPkMismatch { - tbl_name: String, - expected: IndexSet, - got: IndexSet, - }, - - #[error("existing schema for table '{tbl_name}' columns mismatched, expected: {expected:?}, got: {got:?}")] - ImportedSchemaColumnsMismatch { - tbl_name: String, - expected: IndexMap, - got: IndexMap, - }, } #[allow(clippy::result_large_err)] -pub fn init_schema(conn: &Connection) -> Result { +pub fn init_schema(conn: &Connection) -> Result { let mut dump = String::new(); let tables: HashMap = conn @@ -177,12 +218,47 @@ pub fn init_schema(conn: &Connection) -> Result { parse_sql(dump.as_str()) } +#[derive(Debug, thiserror::Error)] +pub enum ApplySchemaError { + #[error(transparent)] + Rusqlite(#[from] rusqlite::Error), + #[error(transparent)] + Schema(#[from] SchemaError), + #[error(transparent)] + ConstrainedSchema(#[from] ConstrainedSchemaError), + #[error("won't drop table without the destructive flag set (table: '{0}')")] + DropTableWithoutDestructiveFlag(String), + #[error("won't drop table without the destructive flag set (table: '{0}', column: '{1}')")] + DropColumnWithoutDestructiveFlag(String, String), + #[error("can't add a primary key (table: '{0}', column: '{1}')")] + AddPrimaryKey(String, String), + #[error("can't modify primary keys (table: '{0}')")] + 
ModifyPrimaryKeys(String), + + #[error("tried importing an existing schema for table '{0}' due to a failed CREATE TABLE but didn't find anything (this should never happen)")] + ImportedSchemaNotFound(String), + + #[error("existing schema for table '{tbl_name}' primary keys mismatched, expected: {expected:?}, got: {got:?}")] + ImportedSchemaPkMismatch { + tbl_name: String, + expected: IndexSet, + got: IndexSet, + }, + + #[error("existing schema for table '{tbl_name}' columns mismatched, expected: {expected:?}, got: {got:?}")] + ImportedSchemaColumnsMismatch { + tbl_name: String, + expected: IndexMap, + got: IndexMap, + }, +} + #[allow(clippy::result_large_err)] pub fn apply_schema( tx: &Transaction, - schema: &NormalizedSchema, - new_schema: &mut NormalizedSchema, -) -> Result<(), SchemaError> { + schema: &Schema, + new_schema: &mut Schema, +) -> Result<(), ApplySchemaError> { if let Some(name) = schema .tables .keys() @@ -191,12 +267,12 @@ pub fn apply_schema( .next() { // TODO: add options and check flag - return Err(SchemaError::DropTableWithoutDestructiveFlag( + return Err(ApplySchemaError::DropTableWithoutDestructiveFlag( (*name).clone(), )); } - let mut schema_to_merge = NormalizedSchema::default(); + let mut schema_to_merge = Schema::default(); { let new_table_names = new_schema @@ -245,10 +321,10 @@ pub fn apply_schema( let parsed_table = parse_sql(&sql)? 
.tables .remove(name) - .ok_or_else(|| SchemaError::ImportedSchemaNotFound(name.clone()))?; + .ok_or_else(|| ApplySchemaError::ImportedSchemaNotFound(name.clone()))?; if parsed_table.pk != table.pk { - return Err(SchemaError::ImportedSchemaPkMismatch { + return Err(ApplySchemaError::ImportedSchemaPkMismatch { tbl_name: name.clone(), expected: table.pk.clone(), got: parsed_table.pk, @@ -256,7 +332,7 @@ pub fn apply_schema( } if parsed_table.columns != table.columns { - return Err(SchemaError::ImportedSchemaColumnsMismatch { + return Err(ApplySchemaError::ImportedSchemaColumnsMismatch { tbl_name: name.clone(), expected: table.columns.clone(), got: parsed_table.columns, @@ -327,7 +403,7 @@ pub fn apply_schema( debug!("dropped cols: {dropped_cols:?}"); if let Some(col_name) = dropped_cols.into_iter().next() { - return Err(SchemaError::DropColumnWithoutDestructiveFlag( + return Err(ApplySchemaError::DropColumnWithoutDestructiveFlag( name.clone(), col_name.clone(), )); @@ -335,7 +411,7 @@ pub fn apply_schema( // 2. 
check for changed columns - let changed_cols: HashMap = table + let changed_cols: HashMap = table .columns .iter() .filter_map(|(name, col)| { @@ -379,13 +455,17 @@ pub fn apply_schema( for (col_name, col) in new_cols_iter { info!("adding column '{col_name}'"); if col.primary_key { - return Err(SchemaError::AddPrimaryKey(name.clone(), col_name.clone())); + return Err(ApplySchemaError::AddPrimaryKey( + name.clone(), + col_name.clone(), + )); } if !col.nullable && col.default_value.is_none() { - return Err(SchemaError::NotNullableColumnNeedsDefault { + return Err(ConstrainedSchemaError::NotNullableColumnNeedsDefault { tbl_name: name.clone(), name: col_name.clone(), - }); + } + .into()); } tx.execute_batch(&format!("ALTER TABLE {name} ADD COLUMN {}", col))?; } @@ -415,7 +495,7 @@ pub fn apply_schema( .collect::>(); if primary_keys != new_primary_keys { - return Err(SchemaError::ModifyPrimaryKeys(name.clone())); + return Err(ApplySchemaError::ModifyPrimaryKeys(name.clone())); } // "12-step" process to modifying a table @@ -537,7 +617,7 @@ pub fn apply_schema( } #[allow(clippy::result_large_err)] -pub fn parse_sql_to_schema(schema: &mut NormalizedSchema, sql: &str) -> Result<(), SchemaError> { +pub fn parse_sql_to_schema(schema: &mut Schema, sql: &str) -> Result<(), SchemaError> { debug!("parsing {sql}"); let mut parser = sqlite3_parser::lexer::sql::Parser::new(sql.as_bytes()); @@ -549,9 +629,6 @@ pub fn parse_sql_to_schema(schema: &mut NormalizedSchema, sql: &str) -> Result<( return Err(err.into()); } Ok(Some(ref cmd @ Cmd::Stmt(ref stmt))) => match stmt { - Stmt::CreateIndex { unique: true, .. } => { - return Err(SchemaError::UniqueIndex(cmd.clone())) - } Stmt::CreateTable { temporary: true, .. } => return Err(SchemaError::TemporaryTable(cmd.clone())), @@ -570,29 +647,31 @@ pub fn parse_sql_to_schema(schema: &mut NormalizedSchema, sql: &str) -> Result<( options, }, } => { - if let Some(table) = - prepare_table(tbl_name, columns, constraints.as_ref(), options)? 
- { - schema.tables.insert(tbl_name.name.0.clone(), table); - debug!("inserted table: {}", tbl_name.name.0); - } else { - debug!("skipped table: {}", tbl_name.name.0); - } + schema.tables.insert( + tbl_name.name.0.clone(), + prepare_table(tbl_name, columns, constraints.as_ref(), options), + ); + debug!("inserted table: {}", tbl_name.name.0); } Stmt::CreateIndex { - unique: false, - if_not_exists: _, + unique, idx_name, tbl_name, columns, where_clause, + .. } => { if let Some(table) = schema.tables.get_mut(tbl_name.0.as_str()) { - if let Some(index) = - prepare_index(idx_name, tbl_name, columns, where_clause.as_ref())? - { - table.indexes.insert(idx_name.name.0.clone(), index); - } + table.indexes.insert( + idx_name.name.0.clone(), + Index { + name: idx_name.name.0.clone(), + tbl_name: tbl_name.0.clone(), + columns: columns.to_vec(), + where_clause: where_clause.clone(), + unique: *unique, + }, + ); } else { return Err(SchemaError::IndexWithoutTable { tbl_name: tbl_name.0.clone(), @@ -610,53 +689,21 @@ pub fn parse_sql_to_schema(schema: &mut NormalizedSchema, sql: &str) -> Result<( } #[allow(clippy::result_large_err)] -pub fn parse_sql(sql: &str) -> Result { - let mut schema = NormalizedSchema::default(); +pub fn parse_sql(sql: &str) -> Result { + let mut schema = Schema::default(); parse_sql_to_schema(&mut schema, sql)?; Ok(schema) } -#[allow(clippy::result_large_err)] -fn prepare_index( - name: &QualifiedName, - tbl_name: &Name, - columns: &[SortedColumn], - where_clause: Option<&Expr>, -) -> Result, SchemaError> { - debug!("preparing index: {}", name.name.0); - if tbl_name.0.contains("crsql") - & tbl_name.0.contains("sqlite") - & tbl_name.0.starts_with("__corro") - { - return Ok(None); - } - - Ok(Some(NormalizedIndex { - name: name.name.0.clone(), - tbl_name: tbl_name.0.clone(), - columns: columns.to_vec(), - where_clause: where_clause.cloned(), - })) -} - #[allow(clippy::result_large_err)] fn prepare_table( tbl_name: &QualifiedName, columns: &[ColumnDefinition], 
constraints: Option<&Vec>, options: &TableOptions, -) -> Result, SchemaError> { - debug!("preparing table: {}", tbl_name.name.0); - if tbl_name.name.0.contains("crsql") - & tbl_name.name.0.contains("sqlite") - & tbl_name.name.0.starts_with("__corro") - { - debug!("skipping table because of name"); - return Ok(None); - } - +) -> Table { let pk = constraints .and_then(|constraints| { constraints @@ -665,17 +712,17 @@ fn prepare_table( TableConstraint::PrimaryKey { columns, .. } => Some( columns .iter() - .map(|col| match &col.expr { - Expr::Id(id) => Ok(id.0.clone()), - _ => Err(SchemaError::PrimaryKeyExpr), + .filter_map(|col| match &col.expr { + Expr::Id(id) => Some(id.0.clone()), + _ => None, }) - .collect::, SchemaError>>(), + .collect::>(), ), _ => None, }) }) .unwrap_or_else(|| { - Ok(columns + columns .iter() .filter(|&def| { def.constraints.iter().any(|named| { @@ -683,10 +730,10 @@ fn prepare_table( }) }) .map(|def| def.col_name.0.clone()) - .collect()) - })?; + .collect() + }); - Ok(Some(NormalizedTable { + Table { name: tbl_name.name.0.clone(), indexes: IndexMap::new(), columns: columns @@ -714,27 +761,9 @@ fn prepare_table( let primary_key = pk.contains(&def.col_name.0); - if !primary_key && (!nullable && default_value.is_none()) { - return Err(SchemaError::NotNullableColumnNeedsDefault { - tbl_name: tbl_name.name.0.clone(), - name: def.col_name.0.clone(), - }); - } - - if def - .constraints - .iter() - .any(|named| matches!(named.constraint, ColumnConstraint::ForeignKey { .. 
})) - { - return Err(SchemaError::ForeignKey { - tbl_name: tbl_name.name.0.clone(), - name: def.col_name.0.clone(), - }); - } - - Ok(( + ( def.col_name.0.clone(), - NormalizedColumn { + Column { name: def.col_name.0.clone(), sql_type: match def .col_type @@ -783,14 +812,14 @@ fn prepare_table( }), raw: def.clone(), }, - )) + ) }) - .collect::, SchemaError>>()?, + .collect::>(), pk, raw: CreateTableBody::ColumnsAndConstraints { columns: columns.to_vec(), constraints: constraints.cloned(), options: *options, }, - })) + } } From b662de64ff27414a3577968b484e8786450ea99a Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Thu, 26 Oct 2023 14:30:15 -0400 Subject: [PATCH 04/12] this is a little insane, should support more pg clients now --- Cargo.lock | 1 + Cargo.toml | 2 +- crates/corro-agent/src/agent.rs | 9 +- crates/corro-agent/src/api/peer.rs | 1 + crates/corro-agent/src/api/public/mod.rs | 4 +- crates/corro-agent/src/broadcast/mod.rs | 4 +- crates/corro-pg/Cargo.toml | 5 +- crates/corro-pg/src/lib.rs | 815 +++++++++++++---------- crates/corro-pg/src/vtab/mod.rs | 2 + crates/corro-pg/src/vtab/pg_range.rs | 99 +++ crates/corro-pg/src/vtab/pg_type.rs | 307 +++++++++ crates/corro-types/src/schema.rs | 14 +- examples/fly/schemas/todo.sql | 5 - 13 files changed, 901 insertions(+), 367 deletions(-) create mode 100644 crates/corro-pg/src/vtab/mod.rs create mode 100644 crates/corro-pg/src/vtab/pg_range.rs create mode 100644 crates/corro-pg/src/vtab/pg_type.rs delete mode 100644 examples/fly/schemas/todo.sql diff --git a/Cargo.lock b/Cargo.lock index af7ef499..1c487b15 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -837,6 +837,7 @@ dependencies = [ "phf", "postgres-types", "rusqlite", + "spawn", "sqlite3-parser", "sqlparser", "tempfile", diff --git a/Cargo.toml b/Cargo.toml index ee61556f..b3ae1e89 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,7 +51,7 @@ rand = { version = "0.8.5", features = ["small_rng"] } rangemap = { version = "1.3.0" } rcgen = { version = "0.11.1", 
features = ["x509-parser"] } rhai = { version = "1.15.1", features = ["sync"] } -rusqlite = { version = "0.29.0", features = ["serde_json", "time", "bundled", "uuid", "array", "load_extension", "column_decltype"] } +rusqlite = { version = "0.29.0", features = ["serde_json", "time", "bundled", "uuid", "array", "load_extension", "column_decltype", "vtab"] } rustls = { version = "0.21.0", features = ["dangerous_configuration", "quic"] } rustls-pemfile = "1.0.2" seahash = "4.1.0" diff --git a/crates/corro-agent/src/agent.rs b/crates/corro-agent/src/agent.rs index 46861c5c..96b8c7c4 100644 --- a/crates/corro-agent/src/agent.rs +++ b/crates/corro-agent/src/agent.rs @@ -299,7 +299,10 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age pub async fn start(conf: Config, tripwire: Tripwire) -> eyre::Result { let (agent, opts) = setup(conf, tripwire.clone()).await?; - tokio::spawn(run(agent.clone(), opts).inspect(|_| info!("corrosion agent run is done"))); + tokio::spawn(run(agent.clone(), opts).inspect(|res| match res { + Ok(_) => info!("corrosion agent run is done"), + Err(e) => error!("running corrosion agent failed: {e}"), + })); Ok(agent) } @@ -321,7 +324,7 @@ pub async fn run(agent: Agent, opts: AgentOptions) -> eyre::Result<()> { if let Some(pg_conf) = agent.config().api.pg.clone() { info!("Starting PostgreSQL wire-compatible server"); - let pg_server = corro_pg::start(agent.clone(), pg_conf).await?; + let pg_server = corro_pg::start(agent.clone(), pg_conf, tripwire.clone()).await?; info!( "Started PostgreSQL wire-compatible server, listening at {}", pg_server.local_addr @@ -2419,6 +2422,7 @@ async fn handle_changes( } // drain and process current changes! 
+ #[allow(clippy::drain_collect)] if let Err(e) = process_multiple_changes(&agent, buf.drain(..).collect()).await { error!("could not process multiple changes: {e}"); } @@ -2437,6 +2441,7 @@ async fn handle_changes( buf.push((change, src)); if count >= MIN_CHANGES_CHUNK { // drain and process current changes! + #[allow(clippy::drain_collect)] if let Err(e) = process_multiple_changes(&agent, buf.drain(..).collect()).await { error!("could not process multiple changes: {e}"); } diff --git a/crates/corro-agent/src/api/peer.rs b/crates/corro-agent/src/api/peer.rs index 3305cdc3..5a0a814e 100644 --- a/crates/corro-agent/src/api/peer.rs +++ b/crates/corro-agent/src/api/peer.rs @@ -1021,6 +1021,7 @@ pub async fn parallel_sync( debug!("collected member needs and such!"); + #[allow(clippy::manual_try_fold)] let syncers = results.into_iter().fold(Ok(vec![]), |agg, (actor_id, addr, res)| { match res { Ok((needs, tx, read)) => { diff --git a/crates/corro-agent/src/api/public/mod.rs b/crates/corro-agent/src/api/public/mod.rs index 9160c7c3..e849519a 100644 --- a/crates/corro-agent/src/api/public/mod.rs +++ b/crates/corro-agent/src/api/public/mod.rs @@ -14,7 +14,7 @@ use corro_types::{ api::{row_to_change, ExecResponse, ExecResult, QueryEvent, Statement}, broadcast::{ChangeV1, Changeset, Timestamp}, change::SqliteValue, - http::{IoBodyStream, LinesBytesCodec}, + http::IoBodyStream, schema::{apply_schema, parse_sql}, sqlite::SqlitePoolError, }; @@ -207,7 +207,7 @@ enum HandleConnError { } impl From> for HandleConnError { - fn from(value: SendError) -> Self { + fn from(_: SendError) -> Self { HandleConnError::EventsChannelClosed } } diff --git a/crates/corro-agent/src/broadcast/mod.rs b/crates/corro-agent/src/broadcast/mod.rs index 1170aaa7..b2bbe1fc 100644 --- a/crates/corro-agent/src/broadcast/mod.rs +++ b/crates/corro-agent/src/broadcast/mod.rs @@ -231,7 +231,7 @@ pub fn runtime_loop( .map(serde_json::Value::from) .collect::>() }) - .unwrap_or(vec![]), + .unwrap_or_default(), ), 
)), Err(e) => { @@ -425,7 +425,7 @@ pub fn runtime_loop( .map(serde_json::Value::from) .collect::>() }) - .unwrap_or(vec![]), + .unwrap_or_default(), ); let upserted = tx.prepare_cached("INSERT INTO __corro_members (actor_id, address, state, foca_state, rtts) diff --git a/crates/corro-pg/Cargo.toml b/crates/corro-pg/Cargo.toml index 4b1253a7..2fed1c7d 100644 --- a/crates/corro-pg/Cargo.toml +++ b/crates/corro-pg/Cargo.toml @@ -20,10 +20,11 @@ phf = "*" postgres-types = { version = "0.2", features = ["with-time-0_3"] } sqlite3-parser = { workspace = true } fallible-iterator = { workspace = true } +tripwire = { path = "../tripwire" } +tempfile = { workspace = true } [dev-dependencies] tracing-subscriber = { workspace = true } -tempfile = { workspace = true } corro-tests = { path = "../corro-tests" } tokio-postgres = { version = "0.7.10" } -tripwire = { path = "../tripwire" } \ No newline at end of file +spawn = { path = "../spawn" } \ No newline at end of file diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs index e481f58e..5a69c61b 100644 --- a/crates/corro-pg/src/lib.rs +++ b/crates/corro-pg/src/lib.rs @@ -1,10 +1,16 @@ pub mod proto; pub mod proto_ext; pub mod sql_state; +mod vtab; -use std::{borrow::Borrow, collections::HashMap, future::poll_fn, net::SocketAddr, sync::Arc}; +use std::{ + collections::{HashMap, VecDeque}, + future::poll_fn, + net::SocketAddr, + sync::Arc, +}; -use bytes::{Buf, BytesMut}; +use bytes::Buf; use compact_str::CompactString; use corro_types::{ agent::{Agent, KnownDbVersion}, @@ -26,16 +32,16 @@ use pgwire::{ response::{ EmptyQueryResponse, ReadyForQuery, READY_STATUS_IDLE, READY_STATUS_TRANSACTION_BLOCK, }, - startup::SslRequest, + startup::{ParameterStatus, SslRequest}, PgWireBackendMessage, PgWireFrontendMessage, }, tokio::PgWireMessageServerCodec, }; use postgres_types::{FromSql, Type}; -use rusqlite::{named_params, types::ValueRef, Connection, Statement}; +use rusqlite::{named_params, types::ValueRef, 
vtab::eponymous_only_module, Connection, Statement}; use sqlite3_parser::ast::{ - Cmd, CreateTableBody, Expr, Id, InsertBody, Literal, Name, OneSelect, ResultColumn, Select, - SelectTable, Stmt, + As, Cmd, CreateTableBody, Expr, FromClause, Id, InsertBody, Literal, Name, OneSelect, + ResultColumn, Select, SelectTable, Stmt, }; use tokio::{ io::{AsyncReadExt, AsyncWriteExt, ReadBuf}, @@ -45,8 +51,12 @@ use tokio::{ }; use tokio_util::codec::Framed; use tracing::{debug, error, info, trace, warn}; +use tripwire::{Outcome, PreemptibleFutureExt, Tripwire}; -use crate::sql_state::SqlState; +use crate::{ + sql_state::SqlState, + vtab::{pg_range::PgRangeTable, pg_type::PgTypeTable}, +}; type BoxError = Box; @@ -127,7 +137,7 @@ impl ParsedCmd { Stmt::Insert { .. } => Tag::new_for_execution("INSERT", rows), Stmt::Pragma(_, _) => Tag::new_for_execution("PRAGMA", rows), Stmt::Reindex { .. } => Tag::new_for_execution("REINDEX", rows), - Stmt::Release(_) => Tag::new_for_execution("RELEAE", rows), + Stmt::Release(_) => Tag::new_for_execution("RELEASE", rows), Stmt::Rollback { .. 
} => Tag::new_for_execution("ROLLBACK", rows), Stmt::Savepoint(_) => Tag::new_for_execution("SAVEPOINT", rows), @@ -139,29 +149,14 @@ impl ParsedCmd { } } -#[derive(Clone, Debug)] -struct Query { - cmds: Vec, -} - -impl Query { - fn len(&self) -> usize { - self.cmds.len() - } - - fn is_empty(&self) -> bool { - self.cmds.is_empty() - } -} - -fn parse_query(sql: &str) -> Result { - let mut cmds = vec![]; +fn parse_query(sql: &str) -> Result, sqlite3_parser::lexer::sql::Error> { + let mut cmds = VecDeque::new(); let mut parser = sqlite3_parser::lexer::sql::Parser::new(sql.as_bytes()); loop { match parser.next() { Ok(Some(cmd)) => { - cmds.push(ParsedCmd(cmd)); + cmds.push_back(ParsedCmd(cmd)); } Ok(None) => { break; @@ -170,7 +165,7 @@ fn parse_query(sql: &str) -> Result { } } - Ok(Query { cmds }) + Ok(cmds) } enum OpenTx { @@ -215,13 +210,31 @@ async fn peek_for_sslrequest( } } -pub async fn start(agent: Agent, pg: PgConfig) -> std::io::Result { +#[derive(Debug, thiserror::Error)] +pub enum PgStartError { + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + Rusqlite(#[from] rusqlite::Error), +} + +pub async fn start( + agent: Agent, + pg: PgConfig, + mut tripwire: Tripwire, +) -> Result { let server = TcpListener::bind(pg.bind_addr).await?; let local_addr = server.local_addr()?; + // let tmp_dir = tempfile::TempDir::new()?; + // let pg_system_path = tmp_dir.path().join("pg_system.sqlite"); + tokio::spawn(async move { loop { - let (mut conn, remote_addr) = server.accept().await?; + let (mut conn, remote_addr) = match server.accept().preemptible(&mut tripwire).await { + Outcome::Completed(res) => res?, + Outcome::Preempted(_) => break, + }; info!("accepted a conn, addr: {remote_addr}"); let agent = agent.clone(); @@ -241,7 +254,6 @@ pub async fn start(agent: Agent, pg: PgConfig) -> std::io::Result { match msg { PgWireFrontendMessage::Startup(startup) => { info!("received startup message: {startup:?}"); - println!("huh..."); } _ => { framed @@ 
-265,6 +277,14 @@ pub async fn start(agent: Agent, pg: PgConfig) -> std::io::Result { pgwire::messages::startup::Authentication::Ok, )) .await?; + + framed + .feed(PgWireBackendMessage::ParameterStatus(ParameterStatus::new( + "server_version".into(), + "14.0.0".into(), + ))) + .await?; + framed .feed(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new( READY_STATUS_IDLE, @@ -340,6 +360,9 @@ pub async fn start(agent: Agent, pg: PgConfig) -> std::io::Result { let conn = agent.pool().client_dedicated().unwrap(); println!("opened connection"); + conn.create_module("pg_type", eponymous_only_module::(), None)?; + conn.create_module("pg_range", eponymous_only_module::(), None)?; + let schema = match compute_schema(&conn) { Ok(schema) => schema, Err(e) => { @@ -401,8 +424,8 @@ pub async fn start(agent: Agent, pg: PgConfig) -> std::io::Result { continue; } PgWireFrontendMessage::Parse(parse) => { - let mut query = match parse_query(parse.query()) { - Ok(query) => query, + let mut cmds = match parse_query(parse.query()) { + Ok(cmds) => cmds, Err(e) => { back_tx.blocking_send( ( @@ -422,30 +445,28 @@ pub async fn start(agent: Agent, pg: PgConfig) -> std::io::Result { } }; - if query.len() > 1 { - back_tx.blocking_send( - ( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".to_owned(), - sql_state::SqlState::PROTOCOL_VIOLATION - .code() - .into(), - "only 1 command per Parse is allowed".into(), - ) + let parsed_cmd = match cmds.pop_front() { + Some(cmd) => cmd, + None => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + sql_state::SqlState::PROTOCOL_VIOLATION + .code() + .into(), + "only 1 command per Parse is allowed" + .into(), + ) + .into(), + ), + true, + ) .into(), - ), - true, - ) - .into(), - )?; - continue; - } - - let parsed_cmd = if query.is_empty() { - None - } else { - Some(query.cmds.remove(0)) + )?; + continue; + } }; println!("parsed cmd: {parsed_cmd:?}"); @@ -475,7 +496,7 @@ pub 
async fn start(agent: Agent, pg: PgConfig) -> std::io::Result { parse.name().as_deref().unwrap_or("").into(), ( parse.query().clone(), - parsed_cmd, + Some(parsed_cmd), prepped, parse .type_oids() @@ -488,7 +509,7 @@ pub async fn start(agent: Agent, pg: PgConfig) -> std::io::Result { back_tx.blocking_send( ( PgWireBackendMessage::ParseComplete(ParseComplete::new()), - true, + false, ) .into(), )?; @@ -638,7 +659,7 @@ pub async fn start(agent: Agent, pg: PgConfig) -> std::io::Result { fields.iter().map(Into::into).collect(), ), ), - true, + false, ) .into(), )?; @@ -1001,6 +1022,7 @@ pub async fn start(agent: Agent, pg: PgConfig) -> std::io::Result { loop { if count >= max_rows { + // forget the Rows iterator here so as to not reset the statement! std::mem::forget(rows); back_tx.blocking_send( ( @@ -1152,14 +1174,7 @@ pub async fn start(agent: Agent, pg: PgConfig) -> std::io::Result { continue; } - let mut cmd_iter = parsed_query.cmds.into_iter().peekable(); - - loop { - let cmd = match cmd_iter.next() { - None => break, - Some(cmd) => cmd, - }; - + for cmd in parsed_query.into_iter() { // need to start an implicit transaction if open_tx.is_none() && !cmd.is_begin() { if let Err(e) = conn.execute_batch("BEGIN") { @@ -1565,9 +1580,10 @@ pub async fn start(agent: Agent, pg: PgConfig) -> std::io::Result { Ok::<_, BoxError>(()) }); - return Ok(PgServer { local_addr }); + Ok(PgServer { local_addr }) } +#[allow(clippy::result_large_err)] fn name_to_type(name: &str) -> Result { match name.to_uppercase().as_ref() { "ANY" => Ok(Type::ANY), @@ -1718,11 +1734,6 @@ enum SqliteName { DoublyQualified(Name, Name, Name), } -enum ParamType { - Type(SqliteType), - FromColumn(SqliteName), -} - fn expr_to_name(expr: &Expr) -> Option { match expr { Expr::Id(id) => Some(SqliteNameRef::Id(id)), @@ -1733,35 +1744,36 @@ fn expr_to_name(expr: &Expr) -> Option { } } -fn literal_type(expr: &Expr) -> Option { - match expr { - Expr::Literal(lit) => match lit { - Literal::Numeric(num) => { - if 
num.parse::().is_ok() { - Some(SqliteType::Integer) - } else if num.parse::().is_ok() { - Some(SqliteType::Real) - } else { - // this should be unreachable... - None - } - } - Literal::String(_) => Some(SqliteType::Text), - Literal::Blob(_) => Some(SqliteType::Blob), - Literal::Keyword(keyword) => { - // TODO: figure out what this is... - warn!("got a keyword: {keyword}"); - None - } - Literal::Null => Some(SqliteType::Null), - Literal::CurrentDate | Literal::CurrentTime | Literal::CurrentTimestamp => { - // TODO: make this configurable at connection time or something - Some(SqliteType::Text) - } - }, - _ => None, - } -} +// determines the type of a literal type if any +// fn literal_type(expr: &Expr) -> Option { +// match expr { +// Expr::Literal(lit) => match lit { +// Literal::Numeric(num) => { +// if num.parse::().is_ok() { +// Some(SqliteType::Integer) +// } else if num.parse::().is_ok() { +// Some(SqliteType::Real) +// } else { +// // this should be unreachable... +// None +// } +// } +// Literal::String(_) => Some(SqliteType::Text), +// Literal::Blob(_) => Some(SqliteType::Blob), +// Literal::Keyword(keyword) => { +// // TODO: figure out what this is... +// warn!("got a keyword: {keyword}"); +// None +// } +// Literal::Null => Some(SqliteType::Null), +// Literal::CurrentDate | Literal::CurrentTime | Literal::CurrentTimestamp => { +// // TODO: make this configurable at connection time or something +// Some(SqliteType::Text) +// } +// }, +// _ => None, +// } +// } fn handle_lhs_rhs(lhs: &Expr, rhs: &Expr) -> Option { match ( @@ -1773,68 +1785,24 @@ fn handle_lhs_rhs(lhs: &Expr, rhs: &Expr) -> Option { } } -// fn handle_select(select: &Select, params: &mut Vec) { -// let mut aliases = HashMap::new(); -// match &select.body.select { -// OneSelect::Select { -// columns, -// from, -// where_clause, -// .. 
-// } => { -// // process FROM for aliases only - -// if let Some(from) = from { -// if let Some(select) = from.select.as_deref() { -// match select { -// SelectTable::Table(name, alias, _) => {} -// SelectTable::TableCall(name, _, alias) => {} -// SelectTable::Select(_, alias) => {} -// SelectTable::Sub(_, _) => {} -// } -// } -// } - -// for col in columns.iter() { -// match col { -// ResultColumn::Expr(expr, _) => { -// if is_param(expr) { -// params.push(ParamType::Type(SqliteType::Text)); -// } else { -// extract_param(expr, &aliases, params); -// } -// } -// _ => { -// // nothing to do here I don't think -// } -// } -// } - -// if let Some(from) = from { -// if let Some(select) = from.select.as_deref() { -// match select { -// SelectTable::Table(_, _, _) => {} -// SelectTable::TableCall(_, _, _) => {} -// SelectTable::Select(_, _) => {} -// SelectTable::Sub(_, _) => {} -// } -// } -// } -// } -// OneSelect::Values(_) => { -// // TODO: handle values -// } -// } -// } - -fn extract_param(expr: &Expr, tables: &HashMap, params: &mut Vec) { +fn extract_param( + schema: &Schema, + expr: &Expr, + tables: &HashMap, + params: &mut Vec, +) { match expr { + // expr BETWEEN expr AND expr Expr::Between { - lhs, start, end, .. + lhs: _, + start: _, + end: _, + not: _, } => {} + + // expr operator expr Expr::Binary(lhs, _, rhs) => { if let Some(name) = handle_lhs_rhs(lhs, rhs) { - println!("HANDLED LHS RHS: {name:?}"); match name { // not aliased! 
SqliteName::Id(id) => { @@ -1858,127 +1826,257 @@ fn extract_param(expr: &Expr, tables: &HashMap, params: &mut Vec } } } + + // CASE expr [WHEN expr THEN expr, ..., ELSE expr] Expr::Case { - base, - when_then_pairs, - else_expr, + base: _, + when_then_pairs: _, + else_expr: _, + } => {} + + // CAST ( expr AS type-name ) + Expr::Cast { + expr: _, + type_name: _, } => {} - Expr::Cast { expr, type_name } => {} + + // expr COLLATE collation-name Expr::Collate(_, _) => {} + + // schema-name.table-name.column-name Expr::DoublyQualified(_, _, _) => {} - Expr::Exists(_) => {} + + // EXISTS ( select ) + Expr::Exists(select) => handle_select(schema, select, params), + + // function-name ( [DISTINCT] expr, ... ) filter-clause over-clause Expr::FunctionCall { - name, - distinctness, - args, - filter_over, + name: _, + distinctness: _, + args: _, + filter_over: _, + } => {} + + Expr::FunctionCallStar { + name: _, + filter_over: _, } => {} - Expr::FunctionCallStar { name, filter_over } => {} + + // id Expr::Id(_) => {} - Expr::InList { lhs, not, rhs } => {} - Expr::InSelect { lhs, not, rhs } => {} - Expr::InTable { - lhs, - not, + + // expr IN ( expr, ... ) + Expr::InList { lhs, not: _, rhs } => { + if let Some(rhs) = rhs { + for expr in rhs.iter() { + if let Some(name) = handle_lhs_rhs(lhs, expr) { + println!("HANDLED LHS RHS: {name:?}"); + match name { + // not aliased! 
+ SqliteName::Id(id) => { + // find the first one to match + for (_, table) in tables.iter() { + if let Some(col) = table.columns.get(&id.0) { + params.push(col.sql_type); + break; + } + } + } + SqliteName::Name(_) => {} + SqliteName::Qualified(tbl_name, col_name) + | SqliteName::DoublyQualified(_, tbl_name, col_name) => { + if let Some(table) = tables.get(&tbl_name.0) { + if let Some(col) = table.columns.get(&col_name.0) { + params.push(col.sql_type); + } + } + } + } + } + } + } + } + + // expr IN ( select ) + Expr::InSelect { + lhs: _, + not: _, rhs, - args, + } => { + // TODO: check LHS here + handle_select(schema, rhs.as_ref(), params); + } + + // expr IN schema-name.table-name | schema-name.table-function ( expr, ... ) + Expr::InTable { + lhs: _, + not: _, + rhs: _, + args: _, } => {} + + // expr IS NULL Expr::IsNull(_) => {} + + // expr [NOT] LIKE | GLOB | REGEXP | MATCH expr Expr::Like { - lhs, - not, - op, - rhs, - escape, + lhs: _, + not: _, + op: _, + rhs: _, + escape: _, } => {} - Expr::Literal(_) => {} + + // NULL | integer | float | text | blob + Expr::Literal(_) => { + // nothing to do + } + + // TODO: Expr::Name(_) => {} + + // expr NOT NULL Expr::NotNull(_) => {} - Expr::Parenthesized(_) => {} + + // ( expr, ... ) + Expr::Parenthesized(exprs) => { + for expr in exprs.iter() { + extract_param(schema, expr, tables, params) + } + } + + // schema-name.table-name Expr::Qualified(_, _) => {} + + // RAISE ( IGNORE | ROLLBACK | ABORT | FAIL [ error ] ) Expr::Raise(_, _) => {} - Expr::Subquery(_) => {} + + // SELECT + Expr::Subquery(select) => handle_select(schema, select, params), + + // NOT | ~ | - | + expr Expr::Unary(_, _) => {} + + // ? 
| $ | : Expr::Variable(_) => {} } } -fn parameter_types(schema: &Schema, stmt: &Stmt) -> Vec { - let mut params = vec![]; +fn handle_select(schema: &Schema, select: &Select, params: &mut Vec) { + match &select.body.select { + OneSelect::Select { + columns, + from, + where_clause, + distinctness: _, + group_by: _, + window_clause: _, + } => { + if let Some(from) = from { + let tables = handle_from(schema, from, params); + if let Some(where_clause) = where_clause { + println!("WHERE CLAUSE: {where_clause:?}"); + extract_param(schema, where_clause, &tables, params); + } + } + for col in columns.iter() { + if let ResultColumn::Expr(expr, _) = col { + // TODO: check against table if we can... + if is_param(expr) { + params.push(SqliteType::Text); + } + } + } + } + OneSelect::Values(values_values) => { + for values in values_values.iter() { + for value in values.iter() { + if is_param(value) { + params.push(SqliteType::Text); + } + } + } + } + } +} - match stmt { - Stmt::Select(select) => match &select.body.select { - OneSelect::Select { - columns, - from, - where_clause, - .. 
- } => { - let mut tables: HashMap = HashMap::new(); - if let Some(from) = from { - if let Some(select) = from.select.as_deref() { - match select { - SelectTable::Table(qname, maybe_alias, _) => { - if let Some(alias) = maybe_alias { - let alias = match alias { - sqlite3_parser::ast::As::As(name) => name.0.clone(), - sqlite3_parser::ast::As::Elided(name) => name.0.clone(), - }; - if let Some(table) = schema.tables.get(&qname.name.0) { - tables.insert(alias, table); - } - } else { - // not aliased - if let Some(table) = schema.tables.get(&qname.name.0) { - tables.insert(qname.name.0.clone(), table); - } - } - } - SelectTable::TableCall(_, _, _) => {} - SelectTable::Select(_, _) => {} - SelectTable::Sub(_, _) => {} - } +fn handle_from<'a>( + schema: &'a Schema, + from: &FromClause, + params: &mut Vec, +) -> HashMap { + let mut tables: HashMap = HashMap::new(); + if let Some(select) = from.select.as_deref() { + match select { + SelectTable::Table(qname, maybe_alias, _) => { + if let Some(alias) = maybe_alias { + let alias = match alias { + As::As(name) => name.0.clone(), + As::Elided(name) => name.0.clone(), + }; + if let Some(table) = schema.tables.get(&qname.name.0) { + tables.insert(alias, table); } - if let Some(joins) = &from.joins { - for join in joins.iter() { - match &join.table { - SelectTable::Table(qname, maybe_alias, _) => { - if let Some(alias) = maybe_alias { - let alias = match alias { - sqlite3_parser::ast::As::As(name) => name.0.clone(), - sqlite3_parser::ast::As::Elided(name) => name.0.clone(), - }; - if let Some(table) = schema.tables.get(&qname.name.0) { - tables.insert(alias, table); - } - } else { - // not aliased - if let Some(table) = schema.tables.get(&qname.name.0) { - tables.insert(qname.name.0.clone(), table); - } - } - } - SelectTable::TableCall(_, _, _) => {} - SelectTable::Select(_, _) => {} - SelectTable::Sub(_, _) => {} - } + } else { + // not aliased + if let Some(table) = schema.tables.get(&qname.name.0) { + 
tables.insert(qname.name.0.clone(), table); + } + } + } + SelectTable::TableCall(_, _, _) => {} + SelectTable::Select(select, _) => { + handle_select(schema, select, params); + } + SelectTable::Sub(_, _) => {} + } + } + if let Some(joins) = &from.joins { + for join in joins.iter() { + match &join.table { + SelectTable::Table(qname, maybe_alias, _) => { + if let Some(alias) = maybe_alias { + let alias = match alias { + As::As(name) => name.0.clone(), + As::Elided(name) => name.0.clone(), + }; + if let Some(table) = schema.tables.get(&qname.name.0) { + tables.insert(alias, table); + } + } else { + // not aliased + if let Some(table) = schema.tables.get(&qname.name.0) { + tables.insert(qname.name.0.clone(), table); } } } - if let Some(where_clause) = where_clause { - println!("WHERE CLAUSE: {where_clause:?}"); - extract_param(where_clause, &tables, &mut params); + SelectTable::TableCall(_, _, _) => {} + SelectTable::Select(select, _) => { + handle_select(schema, select, params); } + SelectTable::Sub(_, _) => {} } - OneSelect::Values(_) => { - // TODO: handle this somehow... - } - }, + } + } + tables +} + +fn parameter_types(schema: &Schema, stmt: &Stmt) -> Vec { + let mut params = vec![]; + + match stmt { + Stmt::Select(select) => handle_select(schema, select, &mut params), Stmt::Delete { tbl_name, where_clause: Some(where_clause), .. 
- } => {} + } => { + let mut tables = HashMap::new(); + if let Some(tbl) = schema.tables.get(&tbl_name.name.0) { + tables.insert(tbl_name.name.0.clone(), tbl); + } + extract_param(schema, where_clause, &tables, &mut params); + } Stmt::Insert { tbl_name, columns, @@ -1988,29 +2086,29 @@ fn parameter_types(schema: &Schema, stmt: &Stmt) -> Vec { println!("GOT AN INSERT TO {tbl_name:?} on columns: {columns:?} w/ body: {body:?}"); if let Some(table) = schema.tables.get(&tbl_name.name.0) { match body { - InsertBody::Select(select, _) => match &select.body.select { - OneSelect::Select { - distinctness, - columns, - from, - where_clause, - group_by, - window_clause, - } => { - // handle this at some point... like any other select! - } - OneSelect::Values(values_values) => { + InsertBody::Select(select, _) => { + if let OneSelect::Values(values_values) = &select.body.select { for values in values_values.iter() { for (i, expr) in values.iter().enumerate() { if is_param(expr) { - if let Some((name, col)) = table.columns.get_index(i) { + // specified columns + let col = if let Some(columns) = columns { + columns + .get(i) + .and_then(|name| table.columns.get(&name.0)) + } else { + table.columns.get_index(i).map(|(_name, col)| col) + }; + if let Some(col) = col { params.push(col.sql_type); } } } } + } else { + handle_select(schema, select, &mut params) } - }, + } InsertBody::DefaultValues => { // nothing to do! 
} @@ -2018,17 +2116,30 @@ fn parameter_types(schema: &Schema, stmt: &Stmt) -> Vec { } } Stmt::Update { - with, - or_conflict, + with: _, + or_conflict: _, tbl_name, - indexed, - sets, + indexed: _, + sets: _, from, where_clause, - returning, - order_by, - limit, - } => {} + returning: _, + order_by: _, + limit: _, + } => { + let mut tables = if let Some(from) = from { + handle_from(schema, from, &mut params) + } else { + Default::default() + }; + if let Some(tbl) = schema.tables.get(&tbl_name.name.0) { + tables.insert(tbl_name.name.0.clone(), tbl); + } + if let Some(where_clause) = where_clause { + println!("WHERE CLAUSE: {where_clause:?}"); + extract_param(schema, where_clause, &tables, &mut params); + } + } _ => { // do nothing, there can't be bound params here! } @@ -2040,7 +2151,7 @@ fn parameter_types(schema: &Schema, stmt: &Stmt) -> Vec { #[cfg(test)] mod tests { use corro_tests::launch_test_agent; - use postgres_types::ToSql; + use spawn::wait_for_all_pending_handles; use tokio_postgres::NoTls; use tripwire::Tripwire; @@ -2051,13 +2162,14 @@ mod tests { _ = tracing_subscriber::fmt::try_init(); let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple(); - let ta = launch_test_agent(|builder| builder.build(), tripwire).await?; + let ta = launch_test_agent(|builder| builder.build(), tripwire.clone()).await?; let server = start( ta.agent.clone(), PgConfig { bind_addr: "127.0.0.1:0".parse()?, }, + tripwire, ) .await?; @@ -2067,85 +2179,96 @@ mod tests { server.local_addr.port() ); - let (mut client, client_conn) = tokio_postgres::connect(&conn_str, NoTls).await?; - // let (mut client, client_conn) = - // tokio_postgres::connect("host=localhost port=5432 user=jerome", NoTls).await?; - println!("client is ready!"); - tokio::spawn(client_conn); - - println!("before prepare"); - let stmt = client.prepare("SELECT 1").await?; - println!( - "after prepare: params: {:?}, columns: {:?}", - stmt.params(), - stmt.columns() - ); - - println!("before query"); - let 
rows = client.query(&stmt, &[]).await?; - - println!("rows count: {}", rows.len()); - for row in rows { - println!("ROW!!! {row:?}"); - } + { + let (mut client, client_conn) = tokio_postgres::connect(&conn_str, NoTls).await?; + // let (mut client, client_conn) = + // tokio_postgres::connect("host=localhost port=5432 user=jerome", NoTls).await?; + println!("client is ready!"); + tokio::spawn(client_conn); + + println!("before prepare"); + let stmt = client.prepare("SELECT 1").await?; + println!( + "after prepare: params: {:?}, columns: {:?}", + stmt.params(), + stmt.columns() + ); + + println!("before query"); + let rows = client.query(&stmt, &[]).await?; + + println!("rows count: {}", rows.len()); + for row in rows { + println!("ROW!!! {row:?}"); + } - println!("before execute"); - let affected = client - .execute("INSERT INTO tests VALUES (1,2)", &[]) - .await?; - println!("after execute, affected: {affected}"); - - let row = client.query_one("SELECT * FROM crsql_changes", &[]).await?; - println!("CHANGE ROW: {row:?}"); - - client - .batch_execute("SELECT 1; SELECT 2; SELECT 3;") - .await?; - println!("after batch exec"); - - client.batch_execute("SELECT 1; BEGIN; SELECT 3;").await?; - println!("after batch exec 2"); - - client.batch_execute("SELECT 3; COMMIT; SELECT 3;").await?; - println!("after batch exec 3"); - - let tx = client.transaction().await?; - println!("after begin I assume"); - let res = tx - .execute( - "INSERT INTO tests VALUES ($1, $2)", - &[&2i64, &"hello world"], - ) - .await?; - println!("res (rows affected): {res}"); - let res = tx - .execute( - "INSERT INTO tests2 VALUES ($1, $2)", - &[&2i64, &"hello world 2"], - ) - .await?; - println!("res (rows affected): {res}"); - tx.commit().await?; - println!("after commit"); - - let row = client - .query_one("SELECT * FROM tests t WHERE t.id = ?", &[&2i64]) - .await?; - println!("ROW: {row:?}"); - - let row = client - .query_one("SELECT * FROM tests t WHERE t.id = ?", &[&2i64]) - .await?; - 
println!("ROW: {row:?}"); - - let row = client + println!("before execute"); + let affected = client + .execute("INSERT INTO tests VALUES (1,2)", &[]) + .await?; + println!("after execute, affected: {affected}"); + + let row = client.query_one("SELECT * FROM crsql_changes", &[]).await?; + println!("CHANGE ROW: {row:?}"); + + client + .batch_execute("SELECT 1; SELECT 2; SELECT 3;") + .await?; + println!("after batch exec"); + + client.batch_execute("SELECT 1; BEGIN; SELECT 3;").await?; + println!("after batch exec 2"); + + client.batch_execute("SELECT 3; COMMIT; SELECT 3;").await?; + println!("after batch exec 3"); + + let tx = client.transaction().await?; + println!("after begin I assume"); + let res = tx + .execute( + "INSERT INTO tests VALUES ($1, $2)", + &[&2i64, &"hello world"], + ) + .await?; + println!("res (rows affected): {res}"); + let res = tx + .execute( + "INSERT INTO tests2 VALUES ($1, $2)", + &[&2i64, &"hello world 2"], + ) + .await?; + println!("res (rows affected): {res}"); + tx.commit().await?; + println!("after commit"); + + let row = client + .query_one("SELECT * FROM tests t WHERE t.id = ?", &[&2i64]) + .await?; + println!("ROW: {row:?}"); + + let row = client + .query_one("SELECT * FROM tests t WHERE t.id = ?", &[&2i64]) + .await?; + println!("ROW: {row:?}"); + + let row = client + .query_one("SELECT * FROM tests t WHERE t.id IN (?)", &[&2i64]) + .await?; + println!("ROW: {row:?}"); + + let row = client .query_one("SELECT t.id, t.text, t2.text as t2text FROM tests t LEFT JOIN tests2 t2 WHERE t.id = ?", &[&2i64]) .await?; - println!("ROW: {row:?}"); + println!("ROW: {row:?}"); + + println!("t.id: {:?}", row.try_get::<_, i64>(0)); + println!("t.text: {:?}", row.try_get::<_, String>(1)); + println!("t2text: {:?}", row.try_get::<_, String>(2)); + } - println!("t.id: {:?}", row.try_get::<_, i64>(0)); - println!("t.text: {:?}", row.try_get::<_, String>(1)); - println!("t2text: {:?}", row.try_get::<_, String>(2)); + tripwire_tx.send(()).await.ok(); + 
tripwire_worker.await; + wait_for_all_pending_handles().await; Ok(()) } diff --git a/crates/corro-pg/src/vtab/mod.rs b/crates/corro-pg/src/vtab/mod.rs new file mode 100644 index 00000000..d5ce8e87 --- /dev/null +++ b/crates/corro-pg/src/vtab/mod.rs @@ -0,0 +1,2 @@ +pub mod pg_range; +pub mod pg_type; diff --git a/crates/corro-pg/src/vtab/pg_range.rs b/crates/corro-pg/src/vtab/pg_range.rs new file mode 100644 index 00000000..62c49b13 --- /dev/null +++ b/crates/corro-pg/src/vtab/pg_range.rs @@ -0,0 +1,99 @@ +use std::{marker::PhantomData, os::raw::c_int}; + +use rusqlite::vtab::{ + sqlite3_vtab, sqlite3_vtab_cursor, IndexInfo, VTab, VTabConnection, VTabCursor, Values, +}; + +#[repr(C)] +pub struct PgRangeTable { + /// Base class. Must be first + base: sqlite3_vtab, +} + +unsafe impl<'vtab> VTab<'vtab> for PgRangeTable { + type Aux = (); + type Cursor = PgRangeTableCursor<'vtab>; + + fn connect( + _: &mut VTabConnection, + _aux: Option<&()>, + args: &[&[u8]], + ) -> rusqlite::Result<(String, PgRangeTable)> { + let vtab = PgRangeTable { + base: sqlite3_vtab::default(), + }; + + for arg in args { + println!("arg {:?}", std::str::from_utf8(arg)); + } + + Ok(( + "CREATE TABLE pg_range ( + rngtypid INTEGER, + rngsubtype INTEGER, + rngmultitypid INTEGER, + rngcollation INTEGER, + rngsubopc INTEGER, + rngcanonical TEXT, + rngsubdiff TEXT + )" + .into(), + vtab, + )) + } + + fn best_index(&self, info: &mut IndexInfo) -> rusqlite::Result<()> { + info.set_estimated_cost(1.); + Ok(()) + } + + fn open(&'vtab mut self) -> rusqlite::Result> { + Ok(PgRangeTableCursor::default()) + } +} + +#[derive(Default)] +#[repr(C)] +pub struct PgRangeTableCursor<'vtab> { + /// Base class. 
Must be first + base: sqlite3_vtab_cursor, + /// The rowid + row_id: i64, + phantom: PhantomData<&'vtab PgRangeTable>, +} + +// {"rngtypid":"3904","rngsubtype":"23","rngmultitypid":"4451","rngcollation":"0","rngsubopc":"1978","rngcanonical":"int4range_canonical","rngsubdiff":"int4range_subdiff"} +// {"rngtypid":"3906","rngsubtype":"1700","rngmultitypid":"4532","rngcollation":"0","rngsubopc":"3125","rngcanonical":"-","rngsubdiff":"numrange_subdiff"} +// {"rngtypid":"3908","rngsubtype":"1114","rngmultitypid":"4533","rngcollation":"0","rngsubopc":"3128","rngcanonical":"-","rngsubdiff":"tsrange_subdiff"} +// {"rngtypid":"3910","rngsubtype":"1184","rngmultitypid":"4534","rngcollation":"0","rngsubopc":"3127","rngcanonical":"-","rngsubdiff":"tstzrange_subdiff"} +// {"rngtypid":"3912","rngsubtype":"1082","rngmultitypid":"4535","rngcollation":"0","rngsubopc":"3122","rngcanonical":"daterange_canonical","rngsubdiff":"daterange_subdiff"} +// {"rngtypid":"3926","rngsubtype":"20","rngmultitypid":"4536","rngcollation":"0","rngsubopc":"3124","rngcanonical":"int8range_canonical","rngsubdiff":"int8range_subdiff"} + +unsafe impl VTabCursor for PgRangeTableCursor<'_> { + fn filter( + &mut self, + _idx_num: c_int, + _idx_str: Option<&str>, + _args: &Values<'_>, + ) -> rusqlite::Result<()> { + self.row_id = 1; + Ok(()) + } + + fn next(&mut self) -> rusqlite::Result<()> { + self.row_id += 1; + Ok(()) + } + + fn eof(&self) -> bool { + true // no rows... 
+ } + + fn column(&self, _ctx: &mut rusqlite::vtab::Context, _col: c_int) -> rusqlite::Result<()> { + Ok(()) + } + + fn rowid(&self) -> rusqlite::Result { + Ok(self.row_id) + } +} diff --git a/crates/corro-pg/src/vtab/pg_type.rs b/crates/corro-pg/src/vtab/pg_type.rs new file mode 100644 index 00000000..3340a18b --- /dev/null +++ b/crates/corro-pg/src/vtab/pg_type.rs @@ -0,0 +1,307 @@ +use std::{marker::PhantomData, os::raw::c_int}; + +use postgres_types::Type; +use rusqlite::vtab::{ + sqlite3_vtab, sqlite3_vtab_cursor, IndexInfo, VTab, VTabConnection, VTabCursor, Values, +}; + +#[repr(C)] +pub struct PgTypeTable { + /// Base class. Must be first + base: sqlite3_vtab, +} + +unsafe impl<'vtab> VTab<'vtab> for PgTypeTable { + type Aux = (); + type Cursor = PgTypeTableCursor<'vtab>; + + fn connect( + _: &mut VTabConnection, + _aux: Option<&()>, + args: &[&[u8]], + ) -> rusqlite::Result<(String, PgTypeTable)> { + let vtab = PgTypeTable { + base: sqlite3_vtab::default(), + }; + + for arg in args { + println!("arg {:?}", std::str::from_utf8(arg)); + } + + Ok(( + "CREATE TABLE pg_type ( + oid INTEGER, + typname TEXT, + typnamespace INTEGER, + typowner INTEGER, + typlen INTEGER, + typbyval INTEGER, + typtype TEXT, + typcategory TEXT, + typispreferred INTEGER, + typisdefined INTEGER, + typdelim TEXT, + typrelid INTEGER, + typelem INTEGER, + typarray INTEGER, + typinput TEXT, + typoutput TEXT, + typreceive TEXT, + typsend TEXT, + typmodin TEXT, + typmodout TEXT, + typanalyze TEXT, + typalign TEXT, + typstorage TEXT, + typnotnull INTEGER, + typbasetype INTEGER, + typtypmod INTEGER, + typndims INTEGER, + typcollation INTEGER, + typdefaultbin TEXT, + typdefault TEXT, + typacl TEXT + )" + .into(), + vtab, + )) + } + + fn best_index(&self, info: &mut IndexInfo) -> rusqlite::Result<()> { + info.set_estimated_cost(1.); + Ok(()) + } + + fn open(&'vtab mut self) -> rusqlite::Result> { + Ok(PgTypeTableCursor::default()) + } +} + +#[derive(Default)] +#[repr(C)] +pub struct 
PgTypeTableCursor<'vtab> { + /// Base class. Must be first + base: sqlite3_vtab_cursor, + /// The rowid + row_id: i64, + phantom: PhantomData<&'vtab PgTypeTable>, +} + +struct PgType(Type); + +impl PgType { + fn oid(&self) -> u32 { + self.0.oid() + } + + fn typname(&self) -> &str { + self.0.name() + } + + fn typnamespace(&self) -> &'static str { + "11" + } + fn typowner(&self) -> &'static str { + "10" + } + fn typlen(&self) -> i16 { + match self.0 { + Type::BOOL => 1, + Type::BYTEA => -1, + Type::INT2 => 2, + Type::INT4 => 4, + Type::INT8 => 8, + Type::TEXT => -1, + Type::VARCHAR => -1, + Type::FLOAT4 => 4, + Type::FLOAT8 => 8, + _ => todo!(), + } + } + fn typbyval(&self) -> bool { + match self.0 { + Type::BOOL => true, + Type::BYTEA => false, + Type::INT2 => true, + Type::INT4 => true, + Type::INT8 => true, + Type::TEXT => false, + Type::VARCHAR => false, + Type::FLOAT4 => true, + Type::FLOAT8 => true, + _ => todo!(), + } + } + fn typtype(&self) -> &'static str { + "b" + } + fn typcategory(&self) -> &'static str { + match self.0 { + Type::BOOL => "B", + Type::BYTEA => "U", + Type::INT2 => "N", + Type::INT4 => "N", + Type::INT8 => "N", + Type::TEXT => "S", + Type::VARCHAR => "S", + Type::FLOAT4 => "N", + Type::FLOAT8 => "N", + _ => todo!(), + } + } + fn typispreferred(&self) -> bool { + todo!() + } + fn typisdefined(&self) -> bool { + true + } + fn typdelim(&self) -> &'static str { + todo!() + } + fn typrelid(&self) -> &'static str { + todo!() + } + fn typelem(&self) -> &'static str { + "0" + } + fn typarray(&self) -> &'static str { + todo!() + } + fn typinput(&self) -> String { + format!("{}in", self.0.name()) + } + fn typoutput(&self) -> String { + format!("{}out", self.0.name()) + } + fn typreceive(&self) -> String { + format!("{}recv", self.0.name()) + } + fn typsend(&self) -> String { + format!("{}send", self.0.name()) + } + fn typmodin(&self) -> &'static str { + todo!() + } + fn typmodout(&self) -> &'static str { + todo!() + } + fn typanalyze(&self) -> 
&'static str { + "-" + } + fn typalign(&self) -> &'static str { + todo!() + } + fn typstorage(&self) -> &'static str { + todo!() + } + fn typnotnull(&self) -> bool { + false + } + fn typbasetype(&self) -> &'static str { + "0" + } + fn typtypmod(&self) -> i32 { + -1 + } + fn typndims(&self) -> i32 { + 0 + } + fn typcollation(&self) -> &'static str { + todo!() + } + fn typdefaultbin(&self) -> rusqlite::types::Null { + rusqlite::types::Null + } + fn typdefault(&self) -> Option<&'static str> { + None + } + fn typacl(&self) -> rusqlite::types::Null { + rusqlite::types::Null + } +} + +const PG_TYPES: &[PgType] = &[ + // TINY INT + PgType(Type::BOOL), + // BLOB + PgType(Type::BYTEA), + // INTS + PgType(Type::INT2), + PgType(Type::INT4), + PgType(Type::INT8), + // TEXT + PgType(Type::TEXT), + PgType(Type::VARCHAR), + // REAL + PgType(Type::FLOAT4), + PgType(Type::FLOAT8), +]; + +unsafe impl VTabCursor for PgTypeTableCursor<'_> { + fn filter( + &mut self, + _idx_num: c_int, + _idx_str: Option<&str>, + _args: &Values<'_>, + ) -> rusqlite::Result<()> { + self.row_id = 0; + Ok(()) + } + + fn next(&mut self) -> rusqlite::Result<()> { + self.row_id += 1; + Ok(()) + } + + fn eof(&self) -> bool { + self.row_id >= PG_TYPES.len() as i64 + } + + fn column(&self, ctx: &mut rusqlite::vtab::Context, col: c_int) -> rusqlite::Result<()> { + if let Some(pg_type) = PG_TYPES.get(self.row_id as usize) { + match col { + 0 => ctx.set_result(&pg_type.oid()), + 1 => ctx.set_result(&pg_type.typname()), // pg_type.typname + 2 => ctx.set_result(&pg_type.typnamespace()), // pg_type.typnamespace + 3 => ctx.set_result(&pg_type.typowner()), // pg_type.typowner + 4 => ctx.set_result(&pg_type.typlen()), // pg_type.typlen + 5 => ctx.set_result(&pg_type.typbyval()), // pg_type.typbyval + 6 => ctx.set_result(&pg_type.typtype()), // pg_type.typtype + 7 => ctx.set_result(&pg_type.typcategory()), // pg_type.typcategory + 8 => ctx.set_result(&pg_type.typispreferred()), // pg_type.typispreferred + 9 => 
ctx.set_result(&pg_type.typisdefined()), // pg_type.typisdefined + 10 => ctx.set_result(&pg_type.typdelim()), // pg_type.typdelim + 11 => ctx.set_result(&pg_type.typrelid()), // pg_type.typrelid + 12 => ctx.set_result(&pg_type.typelem()), // pg_type.typelem + 13 => ctx.set_result(&pg_type.typarray()), // pg_type.typarray + 14 => ctx.set_result(&pg_type.typinput()), // pg_type.typinput + 15 => ctx.set_result(&pg_type.typoutput()), // pg_type.typoutput + 16 => ctx.set_result(&pg_type.typreceive()), // pg_type.typreceive + 17 => ctx.set_result(&pg_type.typsend()), // pg_type.typsend + 18 => ctx.set_result(&pg_type.typmodin()), // pg_type.typmodin + 19 => ctx.set_result(&pg_type.typmodout()), // pg_type.typmodout + 20 => ctx.set_result(&pg_type.typanalyze()), // pg_type.typanalyze + 21 => ctx.set_result(&pg_type.typalign()), // pg_type.typalign + 22 => ctx.set_result(&pg_type.typstorage()), // pg_type.typstorage + 23 => ctx.set_result(&pg_type.typnotnull()), // pg_type.typnotnull + 24 => ctx.set_result(&pg_type.typbasetype()), // pg_type.typbasetype + 25 => ctx.set_result(&pg_type.typtypmod()), // pg_type.typtypmod + 26 => ctx.set_result(&pg_type.typndims()), // pg_type.typndims + 27 => ctx.set_result(&pg_type.typcollation()), // pg_type.typcollation + 28 => ctx.set_result(&pg_type.typdefaultbin()), // pg_type.typdefaultbin + 29 => ctx.set_result(&pg_type.typdefault()), // pg_type.typdefault + 30 => ctx.set_result(&pg_type.typacl()), // pg_type.typacl + _ => Err(rusqlite::Error::InvalidColumnIndex(col as usize)), + } + } else { + Err(rusqlite::Error::ModuleError(format!( + "pg type out of bound (row id: {})", + self.row_id + ))) + } + } + + fn rowid(&self) -> rusqlite::Result { + Ok(self.row_id) + } +} diff --git a/crates/corro-types/src/schema.rs b/crates/corro-types/src/schema.rs index e241c2fa..3352c826 100644 --- a/crates/corro-types/src/schema.rs +++ b/crates/corro-types/src/schema.rs @@ -12,7 +12,7 @@ use sqlite3_parser::ast::{ Cmd, ColumnConstraint, 
ColumnDefinition, CreateTableBody, Expr, Name, NamedTableConstraint, QualifiedName, SortedColumn, Stmt, TableConstraint, TableOptions, ToTokens, }; -use tracing::{debug, info}; +use tracing::{debug, info, trace}; #[derive(Debug, Clone, Eq, PartialEq)] pub struct Column { @@ -96,16 +96,16 @@ pub struct Schema { impl Schema { pub fn constrain(&mut self) -> Result<(), ConstrainedSchemaError> { - self.tables.retain(|name, table| { + self.tables.retain(|name, _table| { !(name.contains("crsql") && name.contains("sqlite") && name.starts_with("__corro")) }); for (tbl_name, table) in self.tables.iter() { // this should always be the case... if let CreateTableBody::ColumnsAndConstraints { - columns, + columns: _, constraints, - options, + options: _, } = &table.raw { if let Some(constraints) = constraints { @@ -618,7 +618,7 @@ pub fn apply_schema( #[allow(clippy::result_large_err)] pub fn parse_sql_to_schema(schema: &mut Schema, sql: &str) -> Result<(), SchemaError> { - debug!("parsing {sql}"); + trace!("parsing {sql}"); let mut parser = sqlite3_parser::lexer::sql::Parser::new(sql.as_bytes()); loop { @@ -651,7 +651,7 @@ pub fn parse_sql_to_schema(schema: &mut Schema, sql: &str) -> Result<(), SchemaE tbl_name.name.0.clone(), prepare_table(tbl_name, columns, constraints.as_ref(), options), ); - debug!("inserted table: {}", tbl_name.name.0); + trace!("inserted table: {}", tbl_name.name.0); } Stmt::CreateIndex { unique, @@ -739,7 +739,7 @@ fn prepare_table( columns: columns .iter() .map(|def| { - debug!("visiting column: {}", def.col_name.0); + trace!("visiting column: {}", def.col_name.0); let default_value = def.constraints.iter().find_map(|named| { if let ColumnConstraint::Default(ref expr) = named.constraint { Some(expr.to_string()) diff --git a/examples/fly/schemas/todo.sql b/examples/fly/schemas/todo.sql deleted file mode 100644 index e3ae4d95..00000000 --- a/examples/fly/schemas/todo.sql +++ /dev/null @@ -1,5 +0,0 @@ -CREATE TABLE todos ( - id BLOB PRIMARY KEY, - title 
TEXT NOT NULL DEFAULT '', - completed_at INTEGER -); \ No newline at end of file From 8a2742503d9543f71f83992c37d67155f053b42b Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Thu, 26 Oct 2023 14:55:27 -0400 Subject: [PATCH 05/12] implement an extra column, make it less panicky --- crates/corro-pg/src/vtab/pg_range.rs | 7 ----- crates/corro-pg/src/vtab/pg_type.rs | 43 +++++++++++++++++++--------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/crates/corro-pg/src/vtab/pg_range.rs b/crates/corro-pg/src/vtab/pg_range.rs index 62c49b13..cc702a42 100644 --- a/crates/corro-pg/src/vtab/pg_range.rs +++ b/crates/corro-pg/src/vtab/pg_range.rs @@ -62,13 +62,6 @@ pub struct PgRangeTableCursor<'vtab> { phantom: PhantomData<&'vtab PgRangeTable>, } -// {"rngtypid":"3904","rngsubtype":"23","rngmultitypid":"4451","rngcollation":"0","rngsubopc":"1978","rngcanonical":"int4range_canonical","rngsubdiff":"int4range_subdiff"} -// {"rngtypid":"3906","rngsubtype":"1700","rngmultitypid":"4532","rngcollation":"0","rngsubopc":"3125","rngcanonical":"-","rngsubdiff":"numrange_subdiff"} -// {"rngtypid":"3908","rngsubtype":"1114","rngmultitypid":"4533","rngcollation":"0","rngsubopc":"3128","rngcanonical":"-","rngsubdiff":"tsrange_subdiff"} -// {"rngtypid":"3910","rngsubtype":"1184","rngmultitypid":"4534","rngcollation":"0","rngsubopc":"3127","rngcanonical":"-","rngsubdiff":"tstzrange_subdiff"} -// {"rngtypid":"3912","rngsubtype":"1082","rngmultitypid":"4535","rngcollation":"0","rngsubopc":"3122","rngcanonical":"daterange_canonical","rngsubdiff":"daterange_subdiff"} -// {"rngtypid":"3926","rngsubtype":"20","rngmultitypid":"4536","rngcollation":"0","rngsubopc":"3124","rngcanonical":"int8range_canonical","rngsubdiff":"int8range_subdiff"} - unsafe impl VTabCursor for PgRangeTableCursor<'_> { fn filter( &mut self, diff --git a/crates/corro-pg/src/vtab/pg_type.rs b/crates/corro-pg/src/vtab/pg_type.rs index 3340a18b..9bbecee7 100644 --- a/crates/corro-pg/src/vtab/pg_type.rs +++ 
b/crates/corro-pg/src/vtab/pg_type.rs @@ -115,7 +115,10 @@ impl PgType { Type::VARCHAR => -1, Type::FLOAT4 => 4, Type::FLOAT8 => 8, - _ => todo!(), + _ => { + // TODO: not default... + Default::default() + } } } fn typbyval(&self) -> bool { @@ -129,7 +132,10 @@ impl PgType { Type::VARCHAR => false, Type::FLOAT4 => true, Type::FLOAT8 => true, - _ => todo!(), + _ => { + // TODO: not default... + Default::default() + } } } fn typtype(&self) -> &'static str { @@ -146,26 +152,32 @@ impl PgType { Type::VARCHAR => "S", Type::FLOAT4 => "N", Type::FLOAT8 => "N", - _ => todo!(), + _ => { + // TODO: not default... + Default::default() + } } } fn typispreferred(&self) -> bool { - todo!() + // TODO: not default... + Default::default() } fn typisdefined(&self) -> bool { true } fn typdelim(&self) -> &'static str { - todo!() + // TODO: not default... + Default::default() } - fn typrelid(&self) -> &'static str { - todo!() + fn typrelid(&self) -> i64 { + 0 } fn typelem(&self) -> &'static str { "0" } fn typarray(&self) -> &'static str { - todo!() + // TODO: not default... + Default::default() } fn typinput(&self) -> String { format!("{}in", self.0.name()) @@ -180,19 +192,23 @@ impl PgType { format!("{}send", self.0.name()) } fn typmodin(&self) -> &'static str { - todo!() + // TODO: not default... + Default::default() } fn typmodout(&self) -> &'static str { - todo!() + // TODO: not default... + Default::default() } fn typanalyze(&self) -> &'static str { "-" } fn typalign(&self) -> &'static str { - todo!() + // TODO: not default... + Default::default() } fn typstorage(&self) -> &'static str { - todo!() + // TODO: not default... + Default::default() } fn typnotnull(&self) -> bool { false @@ -207,7 +223,8 @@ impl PgType { 0 } fn typcollation(&self) -> &'static str { - todo!() + // TODO: not default... 
+ Default::default() } fn typdefaultbin(&self) -> rusqlite::types::Null { rusqlite::types::Null From 20441e384f25e6d785b63fb79ae0b13ab145243f Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Thu, 26 Oct 2023 14:55:36 -0400 Subject: [PATCH 06/12] clean up mess from failed experiment --- Cargo.lock | 32 -- Cargo.toml | 2 - crates/corro-agent/Cargo.toml | 1 - crates/corro-agent/src/api/mod.rs | 1 - crates/corro-agent/src/api/pg.rs | 1 - crates/corro-agent/src/api/public/mod.rs | 547 +---------------------- crates/corro-types/Cargo.toml | 3 - crates/corro-types/src/http.rs | 181 -------- crates/corro-types/src/lib.rs | 1 - 9 files changed, 7 insertions(+), 762 deletions(-) delete mode 100644 crates/corro-agent/src/api/pg.rs delete mode 100644 crates/corro-types/src/http.rs diff --git a/Cargo.lock b/Cargo.lock index 1c487b15..8ebfa659 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -759,7 +759,6 @@ dependencies = [ "quoted-string", "rand", "rangemap", - "rmp-serde", "rusqlite", "rustls", "rustls-pemfile", @@ -923,18 +922,15 @@ dependencies = [ "fallible-iterator", "foca", "futures", - "hyper", "indexmap", "itertools", "metrics", "once_cell", "opentelemetry", "parking_lot", - "pin-project-lite", "rand", "rangemap", "rcgen", - "rmp-serde", "rusqlite", "serde", "serde_json", @@ -2624,12 +2620,6 @@ dependencies = [ "windows-sys 0.45.0", ] -[[package]] -name = "paste" -version = "1.0.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" - [[package]] name = "pathdiff" version = "0.2.1" @@ -3169,28 +3159,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "rmp" -version = "0.8.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f9860a6cc38ed1da53456442089b4dfa35e7cedaa326df63017af88385e6b20" -dependencies = [ - "byteorder", - "num-traits", - "paste", -] - -[[package]] -name = "rmp-serde" -version = "1.1.2" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "bffea85eea980d8a74453e5d02a8d93028f3c34725de143085a844ebe953258a" -dependencies = [ - "byteorder", - "rmp", - "serde", -] - [[package]] name = "rusqlite" version = "0.29.0" diff --git a/Cargo.toml b/Cargo.toml index b3ae1e89..0e77143a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,6 @@ hyper = { version = "0.14.26", features = ["h2", "http1", "http2", "server", "tc hyper-rustls = { version = "0.24.0", features = ["http2"] } indexmap = { version = "1.9.3", features = ["serde"] } itertools = { version = "0.10.5" } -rmp-serde = { version = "1.1.2" } metrics = "0.21.0" metrics-exporter-prometheus = "0.12.0" once_cell = "1.17.1" @@ -42,7 +41,6 @@ opentelemetry-otlp = { version = "0.13.0" } opentelemetry-semantic-conventions = { version = "0.12.0" } parking_lot = { version = "0.12.1" } pin-project-lite = "0.2.9" -polonius-the-crab = { version = "0.4.1" } quinn = "0.10.2" quinn-proto = "0.10.5" quinn-plaintext = "0.1.0" diff --git a/crates/corro-agent/Cargo.toml b/crates/corro-agent/Cargo.toml index 76479149..eb70a346 100644 --- a/crates/corro-agent/Cargo.toml +++ b/crates/corro-agent/Cargo.toml @@ -28,7 +28,6 @@ quinn-plaintext = { workspace = true } quoted-string = { workspace = true } rand = { workspace = true } rangemap = { workspace = true } -rmp-serde = { workspace = true } rusqlite = { workspace = true } rustls = { workspace = true } rustls-pemfile = "*" diff --git a/crates/corro-agent/src/api/mod.rs b/crates/corro-agent/src/api/mod.rs index cffa154c..dbabc9f7 100644 --- a/crates/corro-agent/src/api/mod.rs +++ b/crates/corro-agent/src/api/mod.rs @@ -1,3 +1,2 @@ pub mod peer; -pub mod pg; pub mod public; diff --git a/crates/corro-agent/src/api/pg.rs b/crates/corro-agent/src/api/pg.rs deleted file mode 100644 index 8b137891..00000000 --- a/crates/corro-agent/src/api/pg.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/crates/corro-agent/src/api/public/mod.rs 
b/crates/corro-agent/src/api/public/mod.rs index e849519a..16d60d6b 100644 --- a/crates/corro-agent/src/api/public/mod.rs +++ b/crates/corro-agent/src/api/public/mod.rs @@ -1,43 +1,33 @@ use std::{ - collections::HashMap, iter::Peekable, - mem::forget, - ops::{Deref, DerefMut, RangeInclusive}, + ops::RangeInclusive, time::{Duration, Instant}, }; -use axum::{extract, response::IntoResponse, Extension}; -use bytes::{BufMut, Bytes, BytesMut}; +use axum::{response::IntoResponse, Extension}; +use bytes::{BufMut, BytesMut}; use compact_str::ToCompactString; use corro_types::{ agent::{Agent, ChangeError, KnownDbVersion}, api::{row_to_change, ExecResponse, ExecResult, QueryEvent, Statement}, broadcast::{ChangeV1, Changeset, Timestamp}, change::SqliteValue, - http::IoBodyStream, schema::{apply_schema, parse_sql}, sqlite::SqlitePoolError, }; -use futures::StreamExt; use hyper::StatusCode; use itertools::Itertools; use metrics::counter; -use rusqlite::{named_params, params_from_iter, Connection, ToSql, Transaction}; -use serde::{Deserialize, Serialize}; +use rusqlite::{named_params, params_from_iter, ToSql, Transaction}; use spawn::spawn_counted; use tokio::{ sync::{ - mpsc::{self, channel, error::SendError, Receiver, Sender}, + mpsc::{self, channel}, oneshot, }, task::block_in_place, }; -use tokio_util::{ - codec::{Encoder, FramedRead, LengthDelimitedCodec}, - io::StreamReader, - sync::CancellationToken, -}; -use tracing::{debug, error, info, trace, Instrument}; +use tracing::{debug, error, info, trace}; use corro_types::{ broadcast::{BroadcastInput, BroadcastV1}, @@ -156,489 +146,6 @@ where } } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(rename_all = "snake_case")] -#[repr(u8)] -pub enum Stmt { - Prepare(String), - Drop(u32), - Reset(u32), - Columns(u32), - - Execute(u32, Vec), - Query(u32, Vec), - - Next(u32), - - Begin, - Commit, - Rollback, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(rename_all = "snake_case")] 
-#[repr(u8)] -pub enum SqliteResult { - Ok, - Error(String), - - Statement { - id: u32, - params_count: usize, - }, - - Execute { - rows_affected: usize, - last_insert_rowid: i64, - }, - - Columns(Vec), - - // None represents the end of a statement's rows - Row(Option>), -} - -#[derive(Debug, thiserror::Error)] -enum HandleConnError { - #[error(transparent)] - Rusqlite(#[from] rusqlite::Error), - #[error("events channel closed")] - EventsChannelClosed, -} - -impl From> for HandleConnError { - fn from(_: SendError) -> Self { - HandleConnError::EventsChannelClosed - } -} - -#[derive(Clone, Debug)] -struct IncrMap { - map: HashMap, - last: u32, -} - -impl IncrMap { - pub fn insert(&mut self, v: V) -> u32 { - self.last += 1; - self.map.insert(self.last, v); - self.last - } -} - -impl Default for IncrMap { - fn default() -> Self { - Self { - map: Default::default(), - last: Default::default(), - } - } -} - -impl Deref for IncrMap { - type Target = HashMap; - - fn deref(&self) -> &Self::Target { - &self.map - } -} - -impl DerefMut for IncrMap { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.map - } -} - -#[derive(Debug, Default)] -struct ConnCache<'conn> { - prepared: IncrMap>, - cells: Vec, -} - -#[derive(Debug, thiserror::Error)] -enum StmtError { - #[error(transparent)] - Rusqlite(#[from] rusqlite::Error), - #[error("statement not found: {id}")] - StatementNotFound { id: u32 }, -} - -fn handle_stmt<'conn>( - agent: &Agent, - conn: &'conn Connection, - cache: &mut ConnCache<'conn>, - stmt: Stmt, -) -> Result { - match stmt { - Stmt::Prepare(sql) => { - let prepped = conn.prepare(&sql)?; - let params_count = prepped.parameter_count(); - let id = cache.prepared.insert(prepped); - Ok(SqliteResult::Statement { id, params_count }) - } - Stmt::Columns(id) => { - let prepped = cache - .prepared - .get(&id) - .ok_or(StmtError::StatementNotFound { id })?; - Ok(SqliteResult::Columns( - prepped - .column_names() - .into_iter() - .map(|name| name.to_string()) - 
.collect(), - )) - } - Stmt::Execute(id, params) => { - let prepped = cache - .prepared - .get_mut(&id) - .ok_or(StmtError::StatementNotFound { id })?; - let rows_affected = prepped.execute(params_from_iter(params))?; - Ok(SqliteResult::Execute { - rows_affected, - last_insert_rowid: conn.last_insert_rowid(), - }) - } - Stmt::Query(id, params) => { - let prepped = cache - .prepared - .get_mut(&id) - .ok_or(StmtError::StatementNotFound { id })?; - - for (i, param) in params.into_iter().enumerate() { - prepped.raw_bind_parameter(i + 1, param)?; - } - - Ok(SqliteResult::Ok) - } - Stmt::Next(id) => { - let prepped = cache - .prepared - .get_mut(&id) - .ok_or(StmtError::StatementNotFound { id })?; - - // creates an interator for already-bound statements - let mut rows = prepped.raw_query(); - - let res = match rows.next()? { - Some(row) => { - let col_count = row.as_ref().column_count(); - cache.cells.clear(); - for idx in 0..col_count { - let v = row.get::<_, SqliteValue>(idx)?; - cache.cells.push(v); - } - Ok(SqliteResult::Row(Some(cache.cells.drain(..).collect_vec()))) - } - None => Ok(SqliteResult::Row(None)), - }; - - // prevent running Drop so it doesn't reset everything... - forget(rows); - - res - } - Stmt::Begin => { - conn.execute_batch("BEGIN")?; - Ok(SqliteResult::Ok) - } - Stmt::Commit => { - handle_commit(agent, conn)?; - Ok(SqliteResult::Ok) - } - Stmt::Rollback => { - conn.execute_batch("ROLLBACK")?; - Ok(SqliteResult::Ok) - } - Stmt::Drop(id) => { - cache.prepared.remove(&id); - Ok(SqliteResult::Ok) - } - Stmt::Reset(id) => { - let prepped = cache - .prepared - .get_mut(&id) - .ok_or(StmtError::StatementNotFound { id })?; - - // not sure how to reset a statement otherwise.. 
- let rows = prepped.raw_query(); - drop(rows); - - Ok(SqliteResult::Ok) - } - } -} - -fn handle_interactive( - agent: &Agent, - mut queries: Receiver, - events: Sender, -) -> Result<(), HandleConnError> { - let conn = match agent.pool().client_dedicated() { - Ok(conn) => conn, - Err(e) => { - return events - .blocking_send(SqliteResult::Error(e.to_string())) - .map_err(HandleConnError::from); - } - }; - - let mut cache = ConnCache::default(); - - while let Some(stmt) = queries.blocking_recv() { - match handle_stmt(agent, &conn, &mut cache, stmt) { - Ok(res) => { - events.blocking_send(res)?; - } - Err(e) => { - events.blocking_send(SqliteResult::Error(e.to_string()))?; - } - } - } - - Ok(()) -} - -fn handle_commit(agent: &Agent, conn: &Connection) -> rusqlite::Result<()> { - let actor_id = agent.actor_id(); - - let ts = Timestamp::from(agent.clock().new_timestamp()); - - let db_version: i64 = conn - .prepare_cached("SELECT crsql_next_db_version()")? - .query_row((), |row| row.get(0))?; - - let has_changes: bool = conn - .prepare_cached( - "SELECT EXISTS(SELECT 1 FROM crsql_changes WHERE site_id IS NULL AND db_version = ?);", - )? - .query_row([db_version], |row| row.get(0))?; - - if !has_changes { - conn.execute_batch("COMMIT")?; - return Ok(()); - } - - let booked = { - agent - .bookie() - .blocking_write("handle_write_tx(for_actor)") - .for_actor(actor_id) - }; - - let last_seq: i64 = conn - .prepare_cached( - "SELECT MAX(seq) FROM crsql_changes WHERE site_id IS NULL AND db_version = ?", - )? 
- .query_row([db_version], |row| row.get(0))?; - - let mut book_writer = booked.blocking_write("handle_write_tx(book_writer)"); - - let last_version = book_writer.last().unwrap_or_default(); - trace!("last_version: {last_version}"); - let version = last_version + 1; - trace!("version: {version}"); - - conn.prepare_cached( - r#" - INSERT INTO __corro_bookkeeping (actor_id, start_version, db_version, last_seq, ts) - VALUES (:actor_id, :start_version, :db_version, :last_seq, :ts); - "#, - )? - .execute(named_params! { - ":actor_id": actor_id, - ":start_version": version, - ":db_version": db_version, - ":last_seq": last_seq, - ":ts": ts - })?; - - debug!(%actor_id, %version, %db_version, "inserted local bookkeeping row!"); - - conn.execute_batch("COMMIT")?; - - trace!("committed tx, db_version: {db_version}, last_seq: {last_seq:?}"); - - book_writer.insert( - version, - KnownDbVersion::Current { - db_version, - last_seq, - ts, - }, - ); - - let agent = agent.clone(); - - spawn_counted(async move { - let conn = agent.pool().read().await?; - - block_in_place(|| { - // TODO: make this more generic so both sync and local changes can use it. - let mut prepped = conn.prepare_cached(r#" - SELECT "table", pk, cid, val, col_version, db_version, seq, COALESCE(site_id, crsql_site_id()), cl - FROM crsql_changes - WHERE site_id IS NULL - AND db_version = ? 
- ORDER BY seq ASC - "#)?; - let rows = prepped.query_map([db_version], row_to_change)?; - let chunked = ChunkedChanges::new(rows, 0, last_seq, MAX_CHANGES_BYTE_SIZE); - for changes_seqs in chunked { - match changes_seqs { - Ok((changes, seqs)) => { - for (table_name, count) in changes.iter().counts_by(|change| &change.table) - { - counter!("corro.changes.committed", count as u64, "table" => table_name.to_string(), "source" => "local"); - } - process_subs(&agent, &changes); - - trace!("broadcasting changes: {changes:?} for seq: {seqs:?}"); - - let tx_bcast = agent.tx_bcast().clone(); - tokio::spawn(async move { - if let Err(e) = tx_bcast - .send(BroadcastInput::AddBroadcast(BroadcastV1::Change( - ChangeV1 { - actor_id, - changeset: Changeset::Full { - version, - changes, - seqs, - last_seq, - ts, - }, - }, - ))) - .await - { - error!("could not send change message for broadcast: {e}"); - } - }); - } - Err(e) => { - error!("could not process crsql change (db_version: {db_version}) for broadcast: {e}"); - break; - } - } - } - Ok::<_, rusqlite::Error>(()) - })?; - Ok::<_, eyre::Report>(()) - }); - Ok::<_, rusqlite::Error>(()) -} - -#[tracing::instrument(skip_all)] -pub async fn api_v1_begins( - // axum::extract::RawQuery(raw_query): axum::extract::RawQuery, - Extension(agent): Extension, - req_body: extract::RawBody, -) -> impl IntoResponse { - let (mut body_tx, body) = hyper::Body::channel(); - - let req_body = IoBodyStream { body: req_body.0 }; - - let (queries_tx, queries_rx) = channel(512); - let (events_tx, mut events_rx) = channel(512); - let cancel = CancellationToken::new(); - - tokio::spawn({ - let cancel = cancel.clone(); - let events_tx = events_tx.clone(); - async move { - let _drop_guard = cancel.drop_guard(); - - let mut req_reader = - FramedRead::new(StreamReader::new(req_body), LengthDelimitedCodec::default()); - - while let Some(buf_res) = req_reader.next().await { - match buf_res { - Ok(buf) => match rmp_serde::from_slice(&buf) { - Ok(req) => { - if 
let Err(e) = queries_tx.send(req).await { - error!("could not send request into channel: {e}"); - if let Err(e) = events_tx - .send(SqliteResult::Error("request channel closed".into())) - .await - { - error!("could not send error event: {e}"); - } - return; - } - } - Err(e) => { - error!("could not parse message: {e}"); - if let Err(e) = events_tx - .send(SqliteResult::Error("request channel closed".into())) - .await - { - error!("could not send error event: {e}"); - } - } - }, - Err(e) => { - error!("could not read buffer from request body: {e}"); - break; - } - } - } - } - .in_current_span() - }); - - // probably a better way to do this... - spawn_counted( - async move { block_in_place(|| handle_interactive(&agent, queries_rx, events_tx)) } - .in_current_span(), - ); - - tokio::spawn(async move { - let mut ser_buf = BytesMut::new(); - let mut encode_buf = BytesMut::new(); - let mut codec = LengthDelimitedCodec::default(); - - while let Some(event) = events_rx.recv().await { - match rmp_serde::encode::write(&mut (&mut ser_buf).writer(), &event) { - Ok(_) => match codec.encode(ser_buf.split().freeze(), &mut encode_buf) { - Ok(_) => { - if let Err(e) = body_tx.send_data(encode_buf.split().freeze()).await { - error!("could not send tx event to response body: {e}"); - return; - } - } - Err(e) => { - error!("could not encode event: {e}"); - if let Err(e) = body_tx - .send_data(Bytes::from(r#"{"error": "could not encode event"}"#)) - .await - { - error!("could not send encoding error to body: {e}"); - return; - } - } - }, - Err(e) => { - error!("could not serialize event: {e}"); - if let Err(e) = body_tx - .send_data(Bytes::from(r#"{"error": "could not serialize event"}"#)) - .await - { - error!("could not send serialize error to body: {e}"); - return; - } - } - } - } - }); - - hyper::Response::builder() - .status(hyper::StatusCode::SWITCHING_PROTOCOLS) - .body(body) - .unwrap() -} - const MAX_CHANGES_BYTE_SIZE: usize = 8 * 1024; pub async fn make_broadcastable_changes( 
@@ -1139,8 +646,7 @@ pub async fn api_v1_queries( async fn execute_schema(agent: &Agent, statements: Vec) -> eyre::Result<()> { let new_sql: String = statements.join(";"); - let mut partial_schema = parse_sql(&new_sql)?; - partial_schema.constrain()?; + let partial_schema = parse_sql(&new_sql)?; let mut conn = agent.pool().write_priority().await?; @@ -1222,11 +728,9 @@ pub async fn api_v1_db_schema( #[cfg(test)] mod tests { use bytes::Bytes; - use corro_tests::launch_test_agent; use corro_types::{api::RowId, config::Config, schema::SqliteType}; use futures::Stream; use http_body::{combinators::UnsyncBoxBody, Body}; - use spawn::wait_for_all_pending_handles; use tokio::sync::mpsc::error::TryRecvError; use tokio_util::codec::{Decoder, LinesCodec}; use tripwire::Tripwire; @@ -1758,41 +1262,4 @@ mod tests { assert_eq!(chunker.next(), None); } - - #[tokio::test(flavor = "multi_thread")] - async fn test_interactive() -> eyre::Result<()> { - _ = tracing_subscriber::fmt::try_init(); - let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple(); - let ta = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?; - - let (q_tx, q_rx) = channel(1); - let (e_tx, mut e_rx) = channel(1); - - spawn_counted(async move { block_in_place(|| handle_interactive(&ta.agent, q_rx, e_tx)) }); - - q_tx.send(Stmt::Prepare("SELECT 123".into())).await?; - let e = e_rx.recv().await.unwrap(); - println!("e: {e:?}"); - - q_tx.send(Stmt::Query(1, vec![])).await?; - let e = e_rx.recv().await.unwrap(); - println!("e: {e:?}"); - - q_tx.send(Stmt::Next(1)).await?; - let e = e_rx.recv().await.unwrap(); - assert_eq!(e, SqliteResult::Row(Some(vec![SqliteValue::Integer(123)]))); - - q_tx.send(Stmt::Next(1)).await?; - let e = e_rx.recv().await.unwrap(); - assert_eq!(e, SqliteResult::Row(None)); - - drop(q_tx); - drop(e_rx); - - tripwire_tx.send(()).await.ok(); - tripwire_worker.await; - wait_for_all_pending_handles().await; - - Ok(()) - } } diff --git a/crates/corro-types/Cargo.toml 
b/crates/corro-types/Cargo.toml index 5c3649cf..cd9bec9a 100644 --- a/crates/corro-types/Cargo.toml +++ b/crates/corro-types/Cargo.toml @@ -20,15 +20,12 @@ enquote = { workspace = true } fallible-iterator = { workspace = true } foca = { workspace = true } futures = { workspace = true } -hyper = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } -rmp-serde = { workspace = true } metrics = { workspace = true } once_cell = { workspace = true } opentelemetry = { workspace = true } parking_lot = { workspace = true } -pin-project-lite = { workspace = true } rand = { workspace = true } rangemap = { workspace = true } rcgen = { workspace = true } diff --git a/crates/corro-types/src/http.rs b/crates/corro-types/src/http.rs deleted file mode 100644 index 629ee594..00000000 --- a/crates/corro-types/src/http.rs +++ /dev/null @@ -1,181 +0,0 @@ -use std::{ - error::Error, - io, - pin::Pin, - task::{Context, Poll}, -}; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use futures::{ready, Stream}; -use hyper::Body; -use pin_project_lite::pin_project; -use tokio_util::codec::{Decoder, Encoder, LinesCodecError}; - -pin_project! { - pub struct IoBodyStream { - #[pin] - pub body: Body - } -} - -impl Stream for IoBodyStream { - type Item = io::Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - let res = ready!(this.body.poll_next(cx)); - match res { - Some(Ok(b)) => Poll::Ready(Some(Ok(b))), - Some(Err(e)) => { - let io_err = match e - .source() - .and_then(|source| source.downcast_ref::()) - { - Some(io_err) => io::Error::from(io_err.kind()), - None => io::Error::new(io::ErrorKind::Other, e), - }; - Poll::Ready(Some(Err(io_err))) - } - None => Poll::Ready(None), - } - } -} - -// type IoBodyStreamReader = StreamReader; -// type FramedBody = FramedRead; - -pub struct LinesBytesCodec { - // Stored index of the next index to examine for a `\n` character. - // This is used to optimize searching. 
- // For example, if `decode` was called with `abc`, it would hold `3`, - // because that is the next index to examine. - // The next time `decode` is called with `abcde\n`, the method will - // only look at `de\n` before returning. - next_index: usize, - - /// The maximum length for a given line. If `usize::MAX`, lines will be - /// read until a `\n` character is reached. - max_length: usize, - - /// Are we currently discarding the remainder of a line which was over - /// the length limit? - is_discarding: bool, -} - -impl Default for LinesBytesCodec { - /// Returns a `LinesBytesCodec` for splitting up data into lines. - /// - /// # Note - /// - /// The returned `LinesBytesCodec` will not have an upper bound on the length - /// of a buffered line. See the documentation for [`new_with_max_length`] - /// for information on why this could be a potential security risk. - /// - /// [`new_with_max_length`]: crate::codec::LinesBytesCodec::default_with_max_length() - fn default() -> Self { - LinesBytesCodec { - next_index: 0, - max_length: usize::MAX, - is_discarding: false, - } - } -} - -impl Decoder for LinesBytesCodec { - type Item = BytesMut; - type Error = LinesCodecError; - - fn decode(&mut self, buf: &mut BytesMut) -> Result, LinesCodecError> { - loop { - // Determine how far into the buffer we'll search for a newline. If - // there's no max_length set, we'll read to the end of the buffer. - let read_to = std::cmp::min(self.max_length.saturating_add(1), buf.len()); - - let newline_offset = buf[self.next_index..read_to] - .iter() - .position(|b| *b == b'\n'); - - match (self.is_discarding, newline_offset) { - (true, Some(offset)) => { - // If we found a newline, discard up to that offset and - // then stop discarding. On the next iteration, we'll try - // to read a line normally. 
- buf.advance(offset + self.next_index + 1); - self.is_discarding = false; - self.next_index = 0; - } - (true, None) => { - // Otherwise, we didn't find a newline, so we'll discard - // everything we read. On the next iteration, we'll continue - // discarding up to max_len bytes unless we find a newline. - buf.advance(read_to); - self.next_index = 0; - if buf.is_empty() { - return Ok(None); - } - } - (false, Some(offset)) => { - // Found a line! - let newline_index = offset + self.next_index; - self.next_index = 0; - let mut line = buf.split_to(newline_index + 1); - line.truncate(line.len() - 1); - without_carriage_return(&mut line); - return Ok(Some(line)); - } - (false, None) if buf.len() > self.max_length => { - // Reached the maximum length without finding a - // newline, return an error and start discarding on the - // next call. - self.is_discarding = true; - return Err(LinesCodecError::MaxLineLengthExceeded); - } - (false, None) => { - // We didn't find a line or reach the length limit, so the next - // call will resume searching at the current offset. - self.next_index = read_to; - return Ok(None); - } - } - } - } - - fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, LinesCodecError> { - Ok(match self.decode(buf)? { - Some(frame) => Some(frame), - None => { - // No terminating newline - return remaining data, if any - if buf.is_empty() || buf == &b"\r"[..] 
{ - None - } else { - let mut line = buf.split_to(buf.len()); - line.truncate(line.len() - 1); - without_carriage_return(&mut line); - self.next_index = 0; - Some(line) - } - } - }) - } -} - -fn without_carriage_return(s: &mut BytesMut) { - if let Some(&b'\r') = s.last() { - s.truncate(s.len() - 1); - } -} - -impl Encoder for LinesBytesCodec -where - T: AsRef<[u8]>, -{ - type Error = LinesCodecError; - - fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), LinesCodecError> { - let line = line.as_ref(); - buf.reserve(line.len() + 1); - buf.put(line); - buf.put_u8(b'\n'); - Ok(()) - } -} diff --git a/crates/corro-types/src/lib.rs b/crates/corro-types/src/lib.rs index 778962fd..060fce58 100644 --- a/crates/corro-types/src/lib.rs +++ b/crates/corro-types/src/lib.rs @@ -5,7 +5,6 @@ pub mod api; pub mod broadcast; pub mod change; pub mod config; -pub mod http; pub mod members; pub mod pubsub; pub mod schema; From 1283b2657291fd0412d67ba1a03320cf9d3b1162 Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Thu, 26 Oct 2023 14:56:39 -0400 Subject: [PATCH 07/12] remove unused files from old implementation --- crates/corro-pg/src/lib.rs | 6 +- crates/corro-pg/src/proto.rs | 650 ------------------------------- crates/corro-pg/src/proto_ext.rs | 163 -------- 3 files changed, 2 insertions(+), 817 deletions(-) delete mode 100644 crates/corro-pg/src/proto.rs delete mode 100644 crates/corro-pg/src/proto_ext.rs diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs index 5a69c61b..8d81a8cd 100644 --- a/crates/corro-pg/src/lib.rs +++ b/crates/corro-pg/src/lib.rs @@ -1,5 +1,3 @@ -pub mod proto; -pub mod proto_ext; pub mod sql_state; mod vtab; @@ -40,8 +38,8 @@ use pgwire::{ use postgres_types::{FromSql, Type}; use rusqlite::{named_params, types::ValueRef, vtab::eponymous_only_module, Connection, Statement}; use sqlite3_parser::ast::{ - As, Cmd, CreateTableBody, Expr, FromClause, Id, InsertBody, Literal, Name, OneSelect, - ResultColumn, Select, 
SelectTable, Stmt, + As, Cmd, CreateTableBody, Expr, FromClause, Id, InsertBody, Name, OneSelect, ResultColumn, + Select, SelectTable, Stmt, }; use tokio::{ io::{AsyncReadExt, AsyncWriteExt, ReadBuf}, diff --git a/crates/corro-pg/src/proto.rs b/crates/corro-pg/src/proto.rs deleted file mode 100644 index 29cfa126..00000000 --- a/crates/corro-pg/src/proto.rs +++ /dev/null @@ -1,650 +0,0 @@ -//! Contains types that represent the core Postgres wire protocol. - -// this module requires a lot more work to document -// may want to build this automatically from Postgres docs if possible -#![allow(missing_docs)] - -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use std::convert::TryFrom; -use std::fmt::Display; -use std::mem::size_of; -use std::{collections::HashMap, convert::TryInto}; -use tokio_util::codec::{Decoder, Encoder}; - -macro_rules! data_types { - ($($name:ident = $oid:expr, $size: expr)*) => { - #[derive(Debug, Copy, Clone)] - /// Describes a Postgres data type. - pub enum DataTypeOid { - $( - #[allow(missing_docs)] - $name, - )* - /// A type which is not known to this crate. - Unknown(u32), - } - - impl DataTypeOid { - /// Fetch the size in bytes for this data type. - /// Variably-sized types return -1. - pub fn size_bytes(&self) -> i16 { - match self { - $( - Self::$name => $size, - )* - Self::Unknown(_) => unimplemented!(), - } - } - } - - impl From for DataTypeOid { - fn from(value: u32) -> Self { - match value { - $( - $oid => Self::$name, - )* - other => Self::Unknown(other), - } - } - } - - impl From for u32 { - fn from(value: DataTypeOid) -> Self { - match value { - $( - DataTypeOid::$name => $oid, - )* - DataTypeOid::Unknown(other) => other, - } - } - } - }; -} - -// For oid see: -// https://github.com/sfackler/rust-postgres/blob/master/postgres-types/src/type_gen.rs -data_types! 
{ - Unspecified = 0, 0 - - Bool = 16, 1 - - Int2 = 21, 2 - Int4 = 23, 4 - Int8 = 20, 8 - - Float4 = 700, 4 - Float8 = 701, 8 - - Date = 1082, 4 - Timestamp = 1114, 8 - - Text = 25, -1 -} - -/// Describes how to format a given value or set of values. -#[derive(Debug, Copy, Clone)] -pub enum FormatCode { - /// Use the stable text representation. - Text = 0, - /// Use the less-stable binary representation. - Binary = 1, -} - -impl TryFrom for FormatCode { - type Error = ProtocolError; - - fn try_from(value: i16) -> Result { - match value { - 0 => Ok(FormatCode::Text), - 1 => Ok(FormatCode::Binary), - other => Err(ProtocolError::InvalidFormatCode(other)), - } - } -} - -#[derive(Debug)] -pub struct Startup { - pub requested_protocol_version: (i16, i16), - pub parameters: HashMap, -} - -#[derive(Debug)] -pub enum Describe { - Portal(String), - PreparedStatement(String), -} - -#[derive(Debug)] -pub struct Parse { - pub prepared_statement_name: String, - pub query: String, - pub parameter_types: Vec, -} - -#[derive(Debug)] -pub enum BindFormat { - All(FormatCode), - PerColumn(Vec), -} - -#[derive(Debug)] -pub struct Bind { - pub portal: String, - pub prepared_statement_name: String, - pub parameter_values: Vec, - pub result_format: BindFormat, -} - -#[derive(Debug)] -pub enum BindValue { - Text(String), - Binary(Bytes), -} - -#[derive(Debug)] -pub struct Execute { - pub portal: String, - pub max_rows: Option, -} - -#[derive(Debug)] -pub enum FrontendMessage { - SSLRequest, // for SSL negotiation - Startup(Startup), - Parse(Parse), - Describe(Describe), - Bind(Bind), - Sync, - Execute(Execute), - Query(String), - Terminate, -} - -pub trait BackendMessage: std::fmt::Debug { - const TAG: u8; - - fn encode(&self, dst: &mut BytesMut); -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum SqlState { - SuccessfulCompletion, - FeatureNotSupported, - InvalidCursorName, - ConnectionException, - InvalidSQLStatementName, - DataException, - ProtocolViolation, - SyntaxError, - 
InvalidDatetimeFormat, -} - -impl SqlState { - pub fn code(&self) -> &str { - match self { - Self::SuccessfulCompletion => "00000", - Self::FeatureNotSupported => "0A000", - Self::InvalidCursorName => "34000", - Self::ConnectionException => "08000", - Self::InvalidSQLStatementName => "26000", - Self::DataException => "22000", - Self::ProtocolViolation => "08P01", - Self::SyntaxError => "42601", - Self::InvalidDatetimeFormat => "22007", - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Severity { - Error, - Fatal, -} - -impl Severity { - pub fn code(&self) -> &str { - match self { - Self::Fatal => "FATAL", - Self::Error => "ERROR", - } - } -} - -#[derive(thiserror::Error, Debug, Clone)] -pub struct ErrorResponse { - pub sql_state: SqlState, - pub severity: Severity, - pub message: String, -} - -impl ErrorResponse { - pub fn new(sql_state: SqlState, severity: Severity, message: impl Into) -> Self { - ErrorResponse { - sql_state, - severity, - message: message.into(), - } - } - - pub fn error(sql_state: SqlState, message: impl Into) -> Self { - Self::new(sql_state, Severity::Error, message) - } - - pub fn fatal(sql_state: SqlState, message: impl Into) -> Self { - Self::new(sql_state, Severity::Error, message) - } -} - -impl Display for ErrorResponse { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "error") - } -} - -impl BackendMessage for ErrorResponse { - const TAG: u8 = b'E'; - - fn encode(&self, dst: &mut BytesMut) { - dst.put_u8(b'C'); - dst.put_slice(self.sql_state.code().as_bytes()); - dst.put_u8(0); - dst.put_u8(b'S'); - dst.put_slice(self.severity.code().as_bytes()); - dst.put_u8(0); - dst.put_u8(b'M'); - dst.put_slice(self.message.as_bytes()); - dst.put_u8(0); - - dst.put_u8(0); // tag - } -} - -#[derive(Debug)] -pub struct ParameterDescription {} - -impl BackendMessage for ParameterDescription { - const TAG: u8 = b't'; - - fn encode(&self, dst: &mut BytesMut) { - dst.put_i16(0); - } -} - -#[derive(Debug, 
Clone)] -pub struct FieldDescription { - pub name: String, - pub data_type: DataTypeOid, -} - -#[derive(Debug, Clone)] -pub struct RowDescription { - pub fields: Vec, - pub format_code: FormatCode, -} - -impl BackendMessage for RowDescription { - const TAG: u8 = b'T'; - - fn encode(&self, dst: &mut BytesMut) { - dst.put_i16(self.fields.len() as i16); - for field in &self.fields { - dst.put_slice(field.name.as_bytes()); - dst.put_u8(0); - dst.put_i32(0); // table oid - dst.put_i16(0); // column attr number - dst.put_u32(field.data_type.into()); - dst.put_i16(field.data_type.size_bytes()); - dst.put_i32(-1); // data type modifier - dst.put_i16(self.format_code as i16); - } - } -} - -#[derive(Debug)] -pub struct AuthenticationOk; - -impl BackendMessage for AuthenticationOk { - const TAG: u8 = b'R'; - - fn encode(&self, dst: &mut BytesMut) { - dst.put_i32(0); - } -} - -#[derive(Debug)] -pub struct ReadyForQuery; - -impl BackendMessage for ReadyForQuery { - const TAG: u8 = b'Z'; - - fn encode(&self, dst: &mut BytesMut) { - dst.put_u8(b'I'); - } -} - -#[derive(Debug)] -pub struct ParseComplete; - -impl BackendMessage for ParseComplete { - const TAG: u8 = b'1'; - - fn encode(&self, _dst: &mut BytesMut) {} -} - -#[derive(Debug)] -pub struct BindComplete; - -impl BackendMessage for BindComplete { - const TAG: u8 = b'2'; - - fn encode(&self, _dst: &mut BytesMut) {} -} - -#[derive(Debug)] -pub struct NoData; - -impl BackendMessage for NoData { - const TAG: u8 = b'n'; - - fn encode(&self, _dst: &mut BytesMut) {} -} - -#[derive(Debug)] -pub struct EmptyQueryResponse; - -impl BackendMessage for EmptyQueryResponse { - const TAG: u8 = b'I'; - - fn encode(&self, _dst: &mut BytesMut) {} -} - -#[derive(Debug)] -pub struct CommandComplete { - pub command_tag: String, -} - -impl BackendMessage for CommandComplete { - const TAG: u8 = b'C'; - - fn encode(&self, dst: &mut BytesMut) { - dst.put_slice(self.command_tag.as_bytes()); - dst.put_u8(0); - } -} - -#[derive(Debug)] -pub struct 
ParameterStatus { - name: String, - value: String, -} - -impl BackendMessage for ParameterStatus { - const TAG: u8 = b'S'; - - fn encode(&self, dst: &mut BytesMut) { - dst.put_slice(self.name.as_bytes()); - dst.put_u8(0); - dst.put_slice(self.value.as_bytes()); - dst.put_u8(0); - } -} - -impl ParameterStatus { - pub fn new(name: impl Into, value: impl Into) -> Self { - Self { - name: name.into(), - value: value.into(), - } - } -} - -#[derive(Default, Debug)] -pub struct ConnectionCodec { - // most state tracking is handled at a higher level - // however, the actual wire format uses a different header for startup vs normal messages - // so we need to be able to differentiate inside the decoder - startup_received: bool, -} - -impl ConnectionCodec { - pub fn new() -> Self { - Self { - startup_received: false, - } - } -} - -#[derive(thiserror::Error, Debug)] -pub enum ProtocolError { - #[error("io error: {0}")] - Io(#[from] std::io::Error), - #[error("utf8 error: {0}")] - Utf8(#[from] std::string::FromUtf8Error), - #[error("parsing error")] - ParserError, - #[error("invalid message type: {0}")] - InvalidMessageType(u8), - #[error("invalid format code: {0}")] - InvalidFormatCode(i16), -} - -// length prefix, two version components -const STARTUP_HEADER_SIZE: usize = size_of::() + (size_of::() * 2); -// message tag, length prefix -const MESSAGE_HEADER_SIZE: usize = size_of::() + size_of::(); - -impl Decoder for ConnectionCodec { - type Item = FrontendMessage; - type Error = ProtocolError; - - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - if !self.startup_received { - if src.len() < STARTUP_HEADER_SIZE { - return Ok(None); - } - - let mut header_buf = src.clone(); - let message_len = header_buf.get_i32() as usize; - let protocol_version_major = header_buf.get_i16(); - let protocol_version_minor = header_buf.get_i16(); - - if protocol_version_major == 1234i16 && protocol_version_minor == 5679i16 { - src.advance(STARTUP_HEADER_SIZE); - return 
Ok(Some(FrontendMessage::SSLRequest)); - } - - if src.len() < message_len { - src.reserve(message_len - src.len()); - return Ok(None); - } - - src.advance(STARTUP_HEADER_SIZE); - - let mut parameters = HashMap::new(); - - let mut param_str_start_pos = 0; - let mut current_key = None; - for (i, &blah) in src.iter().enumerate() { - if blah == 0 { - let string_value = String::from_utf8(src[param_str_start_pos..i].to_owned())?; - param_str_start_pos = i + 1; - - current_key = match current_key { - Some(key) => { - parameters.insert(key, string_value); - None - } - None => Some(string_value), - } - } - } - - src.advance(message_len - STARTUP_HEADER_SIZE); - - self.startup_received = true; - return Ok(Some(FrontendMessage::Startup(Startup { - requested_protocol_version: (protocol_version_major, protocol_version_minor), - parameters, - }))); - } - - if src.len() < MESSAGE_HEADER_SIZE { - src.reserve(MESSAGE_HEADER_SIZE); - return Ok(None); - } - - let mut header_buf = src.clone(); - let message_tag = header_buf.get_u8(); - let message_len = header_buf.get_i32() as usize; - - if src.len() < message_len { - src.reserve(message_len - src.len()); - return Ok(None); - } - - src.advance(MESSAGE_HEADER_SIZE); - - let read_cstr = |src: &mut BytesMut| -> Result { - let next_null = src - .iter() - .position(|&b| b == 0) - .ok_or(ProtocolError::ParserError)?; - let bytes = src[..next_null].to_owned(); - src.advance(bytes.len() + 1); - Ok(String::from_utf8(bytes)?) 
- }; - - let message = match message_tag { - b'P' => { - let prepared_statement_name = read_cstr(src)?; - let query = read_cstr(src)?; - let num_params = src.get_i16(); - let _params: Vec<_> = (0..num_params).map(|_| src.get_u32()).collect(); - - FrontendMessage::Parse(Parse { - prepared_statement_name, - query, - parameter_types: Vec::new(), - }) - } - b'D' => { - let target_type = src.get_u8(); - let name = read_cstr(src)?; - - FrontendMessage::Describe(match target_type { - b'P' => Describe::Portal(name), - b'S' => Describe::PreparedStatement(name), - _ => return Err(ProtocolError::ParserError), - }) - } - b'S' => FrontendMessage::Sync, - b'B' => { - let portal = read_cstr(src)?; - let prepared_statement_name = read_cstr(src)?; - - let num_param_format_codes = src.get_i16(); - - let mut format_codes: Vec = vec![]; - for _ in 0..num_param_format_codes { - format_codes.push(src.get_i16().try_into()?); - } - - let num_params = src.get_i16(); - let mut params = vec![]; - - let mut last_error = None; - - for i in 0..num_params { - let param_len = src.get_i32() as usize; - let format_code = if num_param_format_codes == 0 { - FormatCode::Text - } else if num_param_format_codes == 1 { - format_codes[0] - } else if format_codes.len() >= (i + 1) as usize { - format_codes[i as usize] - } else { - last_error = Some(ProtocolError::ParserError); - FormatCode::Text - }; - - let bytes = src.copy_to_bytes(param_len); - params.push(match format_code { - FormatCode::Binary => BindValue::Binary(bytes), - FormatCode::Text => match String::from_utf8(bytes.to_vec()) { - Ok(s) => BindValue::Text(s), - Err(e) => { - last_error = Some(ProtocolError::Utf8(e)); - continue; - } - }, - }); - } - - let result_format = match src.get_i16() { - 0 => BindFormat::All(FormatCode::Text), - 1 => BindFormat::All(src.get_i16().try_into()?), - n => { - let mut result_format_codes = Vec::new(); - for _ in 0..n { - result_format_codes.push(src.get_i16().try_into()?); - } - 
BindFormat::PerColumn(result_format_codes) - } - }; - - if let Some(e) = last_error { - return Err(e); - } - - FrontendMessage::Bind(Bind { - portal, - prepared_statement_name, - parameter_values: params, - result_format, - }) - } - b'E' => { - let portal = read_cstr(src)?; - let max_rows = match src.get_i32() { - 0 => None, - other => Some(other), - }; - - FrontendMessage::Execute(Execute { portal, max_rows }) - } - b'Q' => { - let query = read_cstr(src)?; - FrontendMessage::Query(query) - } - b'X' => FrontendMessage::Terminate, - other => return Err(ProtocolError::InvalidMessageType(other)), - }; - - Ok(Some(message)) - } -} - -impl Encoder for ConnectionCodec { - type Error = ProtocolError; - - fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> { - let mut body = BytesMut::new(); - item.encode(&mut body); - - dst.put_u8(T::TAG); - dst.put_i32((body.len() + 4) as i32); - dst.put_slice(&body); - Ok(()) - } -} - -pub struct SSLResponse(pub bool); - -impl Encoder for ConnectionCodec { - type Error = ProtocolError; - - fn encode(&mut self, item: SSLResponse, dst: &mut BytesMut) -> Result<(), Self::Error> { - dst.put_u8(if item.0 { b'S' } else { b'N' }); - Ok(()) - } -} diff --git a/crates/corro-pg/src/proto_ext.rs b/crates/corro-pg/src/proto_ext.rs deleted file mode 100644 index d7582a1d..00000000 --- a/crates/corro-pg/src/proto_ext.rs +++ /dev/null @@ -1,163 +0,0 @@ -//! Contains extensions that make working with the Postgres protocol simpler or more efficient. - -use crate::proto::{ConnectionCodec, FormatCode, ProtocolError, RowDescription}; -use bytes::{BufMut, BytesMut}; -use time::{Date, PrimitiveDateTime}; -use tokio_util::codec::Encoder; - -/// Supports batched rows for e.g. returning portal result sets. -/// -/// NB: this struct only performs limited validation of column consistency across rows. 
-pub struct DataRowBatch { - format_code: FormatCode, - num_cols: usize, - num_rows: usize, - data: BytesMut, - row: BytesMut, -} - -impl DataRowBatch { - /// Creates a new row batch using the given format code, requiring a certain number of columns per row. - pub fn new(format_code: FormatCode, num_cols: usize) -> Self { - Self { - format_code, - num_cols, - num_rows: 0, - data: BytesMut::new(), - row: BytesMut::new(), - } - } - - /// Creates a [DataRowBatch] from the given [RowDescription]. - pub fn from_row_desc(desc: &RowDescription) -> Self { - Self::new(desc.format_code, desc.fields.len()) - } - - /// Starts writing a new row. - /// - /// Returns a [DataRowWriter] that is responsible for the actual value encoding. - pub fn create_row(&mut self) -> DataRowWriter { - self.num_rows += 1; - DataRowWriter::new(self) - } - - /// Returns the number of rows currently written to this batch. - pub fn num_rows(&self) -> usize { - self.num_rows - } -} - -macro_rules! primitive_write { - ($name: ident, $type: ident) => { - #[allow(missing_docs)] - pub fn $name(&mut self, val: $type) { - match self.parent.format_code { - FormatCode::Text => self.write_value(&val.to_string().into_bytes()), - FormatCode::Binary => self.write_value(&val.to_be_bytes()), - }; - } - }; -} - -/// Temporarily leased from a [DataRowBatch] to encode a single row. -pub struct DataRowWriter<'a> { - current_col: usize, - parent: &'a mut DataRowBatch, -} - -impl<'a> DataRowWriter<'a> { - fn new(parent: &'a mut DataRowBatch) -> Self { - parent.row.put_i16(parent.num_cols as i16); - Self { - current_col: 0, - parent, - } - } - - fn write_value(&mut self, data: &[u8]) { - self.current_col += 1; - self.parent.row.put_i32(data.len() as i32); - self.parent.row.put_slice(data); - } - - /// Writes a null value for the next column. - pub fn write_null(&mut self) { - self.current_col += 1; - self.parent.row.put_i32(-1); - } - - /// Writes a string value for the next column. 
- pub fn write_string(&mut self, val: &str) { - self.write_value(val.as_bytes()); - } - - /// Writes a bool value for the next column. - pub fn write_bool(&mut self, val: bool) { - match self.parent.format_code { - FormatCode::Text => self.write_value(if val { "t" } else { "f" }.as_bytes()), - FormatCode::Binary => { - self.current_col += 1; - self.parent.row.put_u8(val as u8); - } - }; - } - - fn pg_date_epoch() -> Date { - Date::from_calendar_date(2000, time::Month::January, 1) - .expect("failed to create pg date epoch") - } - - fn pg_timestamp_epoch() -> PrimitiveDateTime { - Self::pg_date_epoch() - .with_hms(0, 0, 0) - .expect("failed to create pg timestamp epoch") - } - - /// Writes a date value for the next column. - pub fn write_date(&mut self, val: Date) { - match self.parent.format_code { - FormatCode::Binary => { - self.write_int4((val - Self::pg_date_epoch()).whole_days() as i32) - } - FormatCode::Text => self.write_string(&val.to_string()), - } - } - - /// Writes a timestamp value for the next column. 
- pub fn write_timestamp(&mut self, val: PrimitiveDateTime) { - match self.parent.format_code { - FormatCode::Binary => { - self.write_int8((val - Self::pg_timestamp_epoch()).whole_microseconds() as i64); - } - FormatCode::Text => self.write_string(&val.to_string()), - } - } - - primitive_write!(write_int2, i16); - primitive_write!(write_int4, i32); - primitive_write!(write_int8, i64); - primitive_write!(write_float4, f32); - primitive_write!(write_float8, f64); -} - -impl<'a> Drop for DataRowWriter<'a> { - fn drop(&mut self) { - assert_eq!( - self.parent.num_cols, self.current_col, - "dropped a row writer with an invalid number of columns" - ); - - self.parent.data.put_u8(b'D'); - self.parent.data.put_i32((self.parent.row.len() + 4) as i32); - self.parent.data.extend(self.parent.row.split()); - } -} - -impl Encoder for ConnectionCodec { - type Error = ProtocolError; - - fn encode(&mut self, item: DataRowBatch, dst: &mut BytesMut) -> Result<(), Self::Error> { - dst.extend(item.data); - Ok(()) - } -} From 4895431df4050cd9009d7985a8e7160e1e905956 Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Thu, 26 Oct 2023 15:25:14 -0400 Subject: [PATCH 08/12] send subscriptions too --- crates/corro-pg/Cargo.toml | 18 +++++++++--------- crates/corro-pg/src/lib.rs | 13 +++++++++++++ crates/corro-types/src/agent.rs | 22 +++++++++++++++++++++- 3 files changed, 43 insertions(+), 10 deletions(-) diff --git a/crates/corro-pg/Cargo.toml b/crates/corro-pg/Cargo.toml index 2fed1c7d..81d23011 100644 --- a/crates/corro-pg/Cargo.toml +++ b/crates/corro-pg/Cargo.toml @@ -7,24 +7,24 @@ edition = "2021" bytes = { workspace = true } compact_str = { workspace = true } corro-types = { path = "../corro-types" } +fallible-iterator = { workspace = true } futures = { workspace = true } +pgwire = { version = "0.16.1" } +phf = "*" +postgres-types = { version = "0.2", features = ["with-time-0_3"] } rusqlite = { workspace = true } +spawn = { path = "../spawn" } +sqlite3-parser = { workspace = 
true } sqlparser = { version = "0.38" } -pgwire = { version = "0.16.1" } +tempfile = { workspace = true } thiserror = { workspace = true } +time = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true } tracing = { workspace = true } -time = { workspace = true } -phf = "*" -postgres-types = { version = "0.2", features = ["with-time-0_3"] } -sqlite3-parser = { workspace = true } -fallible-iterator = { workspace = true } tripwire = { path = "../tripwire" } -tempfile = { workspace = true } [dev-dependencies] -tracing-subscriber = { workspace = true } corro-tests = { path = "../corro-tests" } tokio-postgres = { version = "0.7.10" } -spawn = { path = "../spawn" } \ No newline at end of file +tracing-subscriber = { workspace = true } \ No newline at end of file diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs index 8d81a8cd..be744476 100644 --- a/crates/corro-pg/src/lib.rs +++ b/crates/corro-pg/src/lib.rs @@ -15,6 +15,7 @@ use corro_types::{ broadcast::Timestamp, config::PgConfig, schema::{parse_sql, Schema, SchemaError, SqliteType, Table}, + sqlite::SqlitePoolError, }; use fallible_iterator::FallibleIterator; use futures::{SinkExt, StreamExt}; @@ -37,6 +38,7 @@ use pgwire::{ }; use postgres_types::{FromSql, Type}; use rusqlite::{named_params, types::ValueRef, vtab::eponymous_only_module, Connection, Statement}; +use spawn::spawn_counted; use sqlite3_parser::ast::{ As, Cmd, CreateTableBody, Expr, FromClause, Id, InsertBody, Name, OneSelect, ResultColumn, Select, SelectTable, Stmt, @@ -1667,6 +1669,17 @@ fn handle_commit(agent: &Agent, conn: &Connection, commit_stmt: &str) -> rusqlit }, ); + drop(book_writer); + + spawn_counted({ + let agent = agent.clone(); + async move { + let conn = agent.pool().read().await?; + block_in_place(|| agent.process_subs_by_db_version(&conn, db_version)); + Ok::<_, SqlitePoolError>(()) + } + }); + Ok(()) } diff --git a/crates/corro-types/src/agent.rs b/crates/corro-types/src/agent.rs index 
fd78ea75..16545bf6 100644 --- a/crates/corro-types/src/agent.rs +++ b/crates/corro-types/src/agent.rs @@ -32,7 +32,7 @@ use tokio::{ }, }; use tokio_util::sync::{CancellationToken, DropGuard}; -use tracing::{debug, error, info, Instrument}; +use tracing::{debug, error, info, trace, Instrument}; use tripwire::Tripwire; use crate::{ @@ -194,6 +194,26 @@ impl Agent { pub fn limits(&self) -> &Limits { &self.0.limits } + + pub fn process_subs_by_db_version(&self, conn: &Connection, db_version: i64) { + trace!("process subs by db version..."); + + let mut matchers_to_delete = vec![]; + + { + let matchers = self.matchers().read(); + for (id, matcher) in matchers.iter() { + if let Err(e) = matcher.process_changes_from_db_version(conn, db_version) { + error!("could not process change w/ matcher {id}, it is probably defunct! {e}"); + matchers_to_delete.push(*id); + } + } + } + + for id in matchers_to_delete { + self.matchers().write().remove(&id); + } + } } #[derive(Debug, Clone)] From 8e824d9ec5eb4ae8d77de4b4bc3c2bde99055907 Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Thu, 26 Oct 2023 17:48:30 -0400 Subject: [PATCH 09/12] need to unquote a few things sometimes --- Cargo.lock | 1 + crates/corro-pg/Cargo.toml | 1 + crates/corro-pg/src/lib.rs | 100 +++++++++++++++++++++---------------- 3 files changed, 59 insertions(+), 43 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8ebfa659..e2bc472c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -830,6 +830,7 @@ dependencies = [ "compact_str 0.7.0", "corro-tests", "corro-types", + "enquote", "fallible-iterator", "futures", "pgwire", diff --git a/crates/corro-pg/Cargo.toml b/crates/corro-pg/Cargo.toml index 81d23011..6b7b1418 100644 --- a/crates/corro-pg/Cargo.toml +++ b/crates/corro-pg/Cargo.toml @@ -23,6 +23,7 @@ tokio = { workspace = true } tokio-util = { workspace = true } tracing = { workspace = true } tripwire = { path = "../tripwire" } +enquote = { workspace = true } [dev-dependencies] corro-tests = { path = 
"../corro-tests" } diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs index be744476..1c49be6f 100644 --- a/crates/corro-pg/src/lib.rs +++ b/crates/corro-pg/src/lib.rs @@ -578,15 +578,18 @@ pub async fn start( } } - back_tx.blocking_send( - ( - PgWireBackendMessage::ParameterDescription( - ParameterDescription::new(oids), - ), - false, - ) - .into(), - )?; + if !oids.is_empty() { + back_tx.blocking_send( + ( + PgWireBackendMessage::ParameterDescription( + ParameterDescription::new(oids), + ), + false, + ) + .into(), + )?; + } + back_tx.blocking_send( ( PgWireBackendMessage::RowDescription( @@ -1466,18 +1469,17 @@ pub async fn start( portals.retain(|_, (stmt_name, _, _, _)| { stmt_name.as_str() != name }); - back_tx.blocking_send( - ( - PgWireBackendMessage::CloseComplete( - CloseComplete::new(), - ), - true, - ) - .into(), - )?; - continue; } - // not finding a statement is not an error + back_tx.blocking_send( + ( + PgWireBackendMessage::CloseComplete( + CloseComplete::new(), + ), + true, + ) + .into(), + )?; + continue; } // portal b'P' => { @@ -1828,8 +1830,16 @@ fn extract_param( SqliteName::Name(_) => {} SqliteName::Qualified(tbl_name, col_name) | SqliteName::DoublyQualified(_, tbl_name, col_name) => { + println!("looking tbl {} for col {}", tbl_name.0, col_name.0); if let Some(table) = tables.get(&tbl_name.0) { - if let Some(col) = table.columns.get(&col_name.0) { + println!("found table! 
{}", table.name); + if let Ok(unquoted) = enquote::unquote(&col_name.0) { + println!("unquoted column as: {unquoted}"); + if let Some(col) = table.columns.get(&unquoted) { + params.push(col.sql_type); + } + } else if let Some(col) = table.columns.get(&col_name.0) { + println!("could not unquote, using original"); params.push(col.sql_type); } } @@ -2020,18 +2030,20 @@ fn handle_from<'a>( if let Some(select) = from.select.as_deref() { match select { SelectTable::Table(qname, maybe_alias, _) => { - if let Some(alias) = maybe_alias { - let alias = match alias { - As::As(name) => name.0.clone(), - As::Elided(name) => name.0.clone(), - }; - if let Some(table) = schema.tables.get(&qname.name.0) { - tables.insert(alias, table); - } + let maybe_table = if let Ok(unquoted) = enquote::unquote(&qname.name.0) { + schema.tables.get(&unquoted) } else { - // not aliased - if let Some(table) = schema.tables.get(&qname.name.0) { - tables.insert(qname.name.0.clone(), table); + schema.tables.get(&qname.name.0) + }; + + if let Some(table) = maybe_table { + if let Some(alias) = maybe_alias { + let alias = match alias { + As::As(name) | As::Elided(name) => name.0.clone(), + }; + tables.insert(alias, table); + } else { + tables.insert(table.name.clone(), table); } } } @@ -2046,18 +2058,20 @@ fn handle_from<'a>( for join in joins.iter() { match &join.table { SelectTable::Table(qname, maybe_alias, _) => { - if let Some(alias) = maybe_alias { - let alias = match alias { - As::As(name) => name.0.clone(), - As::Elided(name) => name.0.clone(), - }; - if let Some(table) = schema.tables.get(&qname.name.0) { - tables.insert(alias, table); - } + let maybe_table = if let Ok(unquoted) = enquote::unquote(&qname.name.0) { + schema.tables.get(&unquoted) } else { - // not aliased - if let Some(table) = schema.tables.get(&qname.name.0) { - tables.insert(qname.name.0.clone(), table); + schema.tables.get(&qname.name.0) + }; + + if let Some(table) = maybe_table { + if let Some(alias) = maybe_alias { + let 
alias = match alias { + As::As(name) | As::Elided(name) => name.0.clone(), + }; + tables.insert(alias, table); + } else { + tables.insert(table.name.clone(), table); } } } From 5a7bc22864a6e37903a75ce7565afa6f41285e31 Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Fri, 27 Oct 2023 11:22:57 -0400 Subject: [PATCH 10/12] handle interrupts, fix last few bugs for compliant clients --- crates/corro-pg/src/lib.rs | 1262 ++++++++++++++++++++---------------- 1 file changed, 709 insertions(+), 553 deletions(-) diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs index 1c49be6f..e742c093 100644 --- a/crates/corro-pg/src/lib.rs +++ b/crates/corro-pg/src/lib.rs @@ -26,7 +26,7 @@ use pgwire::{ }, error::{ErrorInfo, PgWireError}, messages::{ - data::{ParameterDescription, RowDescription}, + data::{NoData, ParameterDescription, RowDescription}, extendedquery::{BindComplete, CloseComplete, ParseComplete, PortalSuspended}, response::{ EmptyQueryResponse, ReadyForQuery, READY_STATUS_IDLE, READY_STATUS_TRANSACTION_BLOCK, @@ -49,7 +49,7 @@ use tokio::{ sync::mpsc::channel, task::block_in_place, }; -use tokio_util::codec::Framed; +use tokio_util::{codec::Framed, sync::CancellationToken}; use tracing::{debug, error, info, trace, warn}; use tripwire::{Outcome, PreemptibleFutureExt, Tripwire}; @@ -78,28 +78,99 @@ impl From<(PgWireBackendMessage, bool)> for BackendResponse { } } +#[derive(Clone, Copy, Debug)] +enum StmtTag { + Select, + InsertAsSelect, + + Insert, + Update, + Delete, + + Alter, + Analyze, + Attach, + Begin, + Commit, + Create, + Detach, + Drop, + Pragma, + Reindex, + Release, + Rollback, + Savepoint, + Vacuum, + + Other, +} + +impl StmtTag { + fn returns_rows_affected(&self) -> bool { + matches!(self, StmtTag::Insert | StmtTag::Update | StmtTag::Delete) + } + fn returns_num_rows(&self) -> bool { + matches!(self, StmtTag::Select | StmtTag::InsertAsSelect) + } + pub fn tag(&self, rows: Option) -> Tag { + match self { + StmtTag::Select => 
Tag::new_for_execution("SELECT", rows), + StmtTag::InsertAsSelect | StmtTag::Insert => Tag::new_for_execution("INSERT", rows), + StmtTag::Update => Tag::new_for_execution("UPDATE", rows), + StmtTag::Delete => Tag::new_for_execution("DELETE", rows), + StmtTag::Alter => Tag::new_for_execution("ALTER", rows), + StmtTag::Analyze => Tag::new_for_execution("ANALYZE", rows), + StmtTag::Attach => Tag::new_for_execution("ATTACH", rows), + StmtTag::Begin => Tag::new_for_execution("BEGIN", rows), + StmtTag::Commit => Tag::new_for_execution("COMMIT", rows), + StmtTag::Create => Tag::new_for_execution("CREATE", rows), + StmtTag::Detach => Tag::new_for_execution("DETACH", rows), + StmtTag::Drop => Tag::new_for_execution("DROP", rows), + StmtTag::Pragma => Tag::new_for_execution("PRAGMA", rows), + StmtTag::Reindex => Tag::new_for_execution("REINDEX", rows), + StmtTag::Release => Tag::new_for_execution("RELEASE", rows), + StmtTag::Rollback => Tag::new_for_execution("ROLLBACK", rows), + StmtTag::Savepoint => Tag::new_for_execution("SAVEPOINT", rows), + StmtTag::Vacuum => Tag::new_for_execution("VACUUM", rows), + StmtTag::Other => Tag::new_for_execution("OK", rows), + } + } +} + +enum Prepared { + Empty, + NonEmpty { + sql: String, + param_types: Vec, + fields: Vec, + tag: StmtTag, + }, +} + +enum Portal<'a> { + Empty { + stmt_name: CompactString, + }, + Parsed { + stmt_name: CompactString, + stmt: Statement<'a>, + result_formats: Vec, + tag: StmtTag, + }, +} + +impl<'a> Portal<'a> { + fn stmt_name(&self) -> &str { + match self { + Portal::Empty { stmt_name } | Portal::Parsed { stmt_name, .. } => stmt_name.as_str(), + } + } +} + #[derive(Clone, Debug)] struct ParsedCmd(Cmd); impl ParsedCmd { - pub fn returns_rows_affected(&self) -> bool { - matches!( - self.0, - Cmd::Stmt(Stmt::Insert { .. }) - | Cmd::Stmt(Stmt::Update { .. }) - | Cmd::Stmt(Stmt::Delete { .. 
}) - ) - } - pub fn returns_num_rows(&self) -> bool { - matches!( - self.0, - Cmd::Stmt(Stmt::Select(_)) - | Cmd::Stmt(Stmt::CreateTable { - body: CreateTableBody::AsSelect(_), - .. - }) - ) - } pub fn is_begin(&self) -> bool { matches!(self.0, Cmd::Stmt(Stmt::Begin(_, _))) } @@ -110,41 +181,41 @@ impl ParsedCmd { matches!(self.0, Cmd::Stmt(Stmt::Rollback { .. })) } - fn tag(&self, rows: Option) -> Tag { + fn tag(&self) -> StmtTag { match &self.0 { Cmd::Stmt(stmt) => match stmt { - Stmt::Select(_) - | Stmt::CreateTable { + Stmt::Select(_) => StmtTag::Select, + Stmt::CreateTable { body: CreateTableBody::AsSelect(_), .. - } => Tag::new_for_query(rows.unwrap_or_default()), - Stmt::AlterTable(_, _) => Tag::new_for_execution("ALTER", rows), - Stmt::Analyze(_) => Tag::new_for_execution("ANALYZE", rows), - Stmt::Attach { .. } => Tag::new_for_execution("ATTACH", rows), - Stmt::Begin(_, _) => Tag::new_for_execution("BEGIN", rows), - Stmt::Commit(_) => Tag::new_for_execution("COMMIT", rows), + } => StmtTag::InsertAsSelect, + Stmt::AlterTable(_, _) => StmtTag::Alter, + Stmt::Analyze(_) => StmtTag::Analyze, + Stmt::Attach { .. } => StmtTag::Attach, + Stmt::Begin(_, _) => StmtTag::Begin, + Stmt::Commit(_) => StmtTag::Commit, Stmt::CreateIndex { .. } | Stmt::CreateTable { .. } | Stmt::CreateTrigger { .. } | Stmt::CreateView { .. } - | Stmt::CreateVirtualTable { .. } => Tag::new_for_execution("CREATE", rows), - Stmt::Delete { .. } => Tag::new_for_execution("DELETE", rows), - Stmt::Detach(_) => Tag::new_for_execution("DETACH", rows), + | Stmt::CreateVirtualTable { .. } => StmtTag::Create, + Stmt::Delete { .. } => StmtTag::Delete, + Stmt::Detach(_) => StmtTag::Detach, Stmt::DropIndex { .. } | Stmt::DropTable { .. } | Stmt::DropTrigger { .. } - | Stmt::DropView { .. } => Tag::new_for_execution("DROP", rows), - Stmt::Insert { .. } => Tag::new_for_execution("INSERT", rows), - Stmt::Pragma(_, _) => Tag::new_for_execution("PRAGMA", rows), - Stmt::Reindex { .. 
} => Tag::new_for_execution("REINDEX", rows), - Stmt::Release(_) => Tag::new_for_execution("RELEASE", rows), - Stmt::Rollback { .. } => Tag::new_for_execution("ROLLBACK", rows), - Stmt::Savepoint(_) => Tag::new_for_execution("SAVEPOINT", rows), - - Stmt::Update { .. } => Tag::new_for_execution("UPDATE", rows), - Stmt::Vacuum(_, _) => Tag::new_for_execution("VACUUM", rows), + | Stmt::DropView { .. } => StmtTag::Drop, + Stmt::Insert { .. } => StmtTag::Insert, + Stmt::Pragma(_, _) => StmtTag::Pragma, + Stmt::Reindex { .. } => StmtTag::Reindex, + Stmt::Release(_) => StmtTag::Release, + Stmt::Rollback { .. } => StmtTag::Rollback, + Stmt::Savepoint(_) => StmtTag::Savepoint, + + Stmt::Update { .. } => StmtTag::Update, + Stmt::Vacuum(_, _) => StmtTag::Vacuum, }, - _ => Tag::new_for_execution("OK", rows), + _ => StmtTag::Other, } } } @@ -241,7 +312,7 @@ pub async fn start( tokio::spawn(async move { conn.set_nodelay(true)?; let ssl = peek_for_sslrequest(&mut conn, false).await?; - println!("SSL? {ssl}"); + trace!("SSL? 
{ssl}"); let mut framed = Framed::new( conn, @@ -249,7 +320,7 @@ pub async fn start( ); let msg = framed.next().await.unwrap()?; - println!("msg: {msg:?}"); + trace!("msg: {msg:?}"); match msg { PgWireFrontendMessage::Startup(startup) => { @@ -293,18 +364,26 @@ pub async fn start( framed.flush().await?; - println!("sent auth ok and ReadyForQuery"); + trace!("sent auth ok and ReadyForQuery"); let (front_tx, mut front_rx) = channel(1024); let (back_tx, mut back_rx) = channel(1024); let (mut sink, mut stream) = framed.split(); + let conn = agent.pool().client_dedicated().unwrap(); + trace!("opened connection"); + + let cancel = CancellationToken::new(); + tokio::spawn({ let back_tx = back_tx.clone(); + let cancel = cancel.clone(); async move { + // cancel stuff if this loop breaks + let _drop_guard = cancel.drop_guard(); + while let Some(decode_res) = stream.next().await { - println!("decode_res: {decode_res:?}"); let msg = match decode_res { Ok(msg) => msg, Err(PgWireError::IoError(io_error)) => { @@ -312,6 +391,7 @@ pub async fn start( break; } Err(e) => { + warn!("could not receive pg frontend message: {e}"); // attempt to send this... 
_ = back_tx.try_send( ( @@ -333,33 +413,42 @@ pub async fn start( front_tx.send(msg).await?; } + debug!("frontend stream is done"); Ok::<_, BoxError>(()) } }); - tokio::spawn(async move { - while let Some(back) = back_rx.recv().await { - match back { - BackendResponse::Message { message, flush } => { - println!("sending: {message:?}"); - sink.feed(message).await?; - if flush { + tokio::spawn({ + let cancel = cancel.clone(); + async move { + let _drop_guard = cancel.drop_guard(); + while let Some(back) = back_rx.recv().await { + match back { + BackendResponse::Message { message, flush } => { + debug!("sending: {message:?}"); + sink.feed(message).await?; + if flush { + sink.flush().await?; + } + } + BackendResponse::Flush => { sink.flush().await?; } } - BackendResponse::Flush => { - sink.flush().await?; - } } + debug!("backend stream is done"); + Ok::<_, std::io::Error>(()) } - Ok::<_, std::io::Error>(()) }); - block_in_place(|| { - let conn = agent.pool().client_dedicated().unwrap(); - println!("opened connection"); + let int_handle = conn.get_interrupt_handle(); + tokio::spawn(async move { + cancel.cancelled().await; + int_handle.interrupt(); + }); + block_in_place(|| { conn.create_module("pg_type", eponymous_only_module::(), None)?; conn.create_module("pg_range", eponymous_only_module::(), None)?; @@ -385,25 +474,14 @@ pub async fn start( } }; - let mut prepared: HashMap< - CompactString, - (String, Option, Statement, Vec), - > = HashMap::new(); - - let mut portals: HashMap< - CompactString, - ( - CompactString, - Option, - Statement, - Vec, - ), - > = HashMap::new(); + let mut prepared: HashMap = HashMap::new(); + + let mut portals: HashMap = HashMap::new(); let mut open_tx = None; 'outer: while let Some(msg) = front_rx.blocking_recv() { - println!("msg: {msg:?}"); + debug!("msg: {msg:?}"); match msg { PgWireFrontendMessage::Startup(_) => { @@ -424,6 +502,7 @@ pub async fn start( continue; } PgWireFrontendMessage::Parse(parse) => { + let name: &str = 
parse.name().as_deref().unwrap_or(""); let mut cmds = match parse_query(parse.query()) { Ok(cmds) => cmds, Err(e) => { @@ -445,66 +524,113 @@ pub async fn start( } }; - let parsed_cmd = match cmds.pop_front() { - Some(cmd) => cmd, + match cmds.pop_front() { None => { - back_tx.blocking_send( - ( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".to_owned(), - sql_state::SqlState::PROTOCOL_VIOLATION - .code() - .into(), - "only 1 command per Parse is allowed" - .into(), - ) - .into(), - ), - true, - ) - .into(), - )?; - continue; + prepared.insert(name.into(), Prepared::Empty); } - }; + Some(parsed_cmd) => { + if !cmds.is_empty() { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + sql_state::SqlState::PROTOCOL_VIOLATION + .code() + .into(), + "only 1 command per Parse is allowed" + .into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } - println!("parsed cmd: {parsed_cmd:?}"); + trace!("parsed cmd: {parsed_cmd:#?}"); - let prepped = match conn.prepare(parse.query()) { - Ok(prepped) => prepped, - Err(e) => { - back_tx.blocking_send( - ( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - e.to_string(), + let prepped = match conn.prepare(parse.query()) { + Ok(prepped) => prepped, + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, ) - .into(), - ), - true, - ) - .into(), - )?; - continue; - } - }; + .into(), + )?; + continue; + } + }; - prepared.insert( - parse.name().as_deref().unwrap_or("").into(), - ( - parse.query().clone(), - Some(parsed_cmd), - prepped, - parse + let mut param_types: Vec = parse .type_oids() .iter() .filter_map(|oid| Type::from_oid(*oid)) - .collect(), - ), - ); + .collect(); + + if param_types.len() != prepped.parameter_count() { + param_types = 
parameter_types(&schema, &parsed_cmd.0) + .into_iter() + .map(|param| match param { + SqliteType::Null => unreachable!(), + SqliteType::Integer => Type::INT8, + SqliteType::Real => Type::FLOAT8, + SqliteType::Text => Type::TEXT, + SqliteType::Blob => Type::BYTEA, + }) + .collect(); + } + + let mut fields = vec![]; + for col in prepped.columns() { + let col_type = match name_to_type( + col.decl_type().unwrap_or("text"), + ) { + Ok(t) => t, + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + e.into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + }; + fields.push(FieldInfo::new( + col.name().to_string(), + None, + None, + col_type, + FieldFormat::Text, + )); + } + + prepared.insert( + name.into(), + Prepared::NonEmpty { + sql: parse.query().clone(), + param_types, + fields, + tag: parsed_cmd.tag(), + }, + ); + } + } back_tx.blocking_send( ( @@ -518,77 +644,51 @@ pub async fn start( let name = desc.name().as_deref().unwrap_or(""); match desc.target_type() { // statement - b'S' => { - if let Some((_, cmd, prepped, param_types)) = - prepared.get(name) - { - let mut oids = vec![]; - let mut fields = vec![]; - for col in prepped.columns() { - let col_type = - match name_to_type( - col.decl_type().unwrap_or("text"), - ) { - Ok(t) => t, - Err(e) => { - back_tx.blocking_send(( - PgWireBackendMessage::ErrorResponse( - e.into(), - ), - true, - ).into())?; - continue 'outer; - } - }; - fields.push(FieldInfo::new( - col.name().to_string(), - None, - None, - col_type, - FieldFormat::Text, - )); - } - - if param_types.len() != prepped.parameter_count() { - if let Some(ParsedCmd(Cmd::Stmt(stmt))) = &cmd { - let params = parameter_types(&schema, stmt); - println!("GOT PARAMS TO OVERRIDE: {params:?}"); - for param in params { - oids.push(match param { - SqliteType::Null => unreachable!(), - SqliteType::Integer => Type::INT8.oid(), - SqliteType::Real => Type::FLOAT8.oid(), - SqliteType::Text => Type::TEXT.oid(), - SqliteType::Blob => 
Type::BYTEA.oid(), - }) - } - } - } else { - for param in 0..prepped.parameter_count() { - // if let Some(t) = param_types.get(param) { - // oids.push(t.oid()); - // } - oids.push( - param_types - .get(param) - .map(|t| t.oid()) - // this should not happen... - .unwrap_or(Type::TEXT.oid()), - ); - } - } - - if !oids.is_empty() { - back_tx.blocking_send( - ( - PgWireBackendMessage::ParameterDescription( - ParameterDescription::new(oids), - ), - false, - ) + b'S' => match prepared.get(name) { + None => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".into(), + "statement not found".into(), + ) .into(), - )?; - } + ), + true, + ) + .into(), + )?; + } + Some(Prepared::Empty) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::NoData(NoData::new()), + false, + ) + .into(), + )?; + } + Some(Prepared::NonEmpty { + param_types, + fields, + .. + }) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ParameterDescription( + ParameterDescription::new( + param_types + .iter() + .map(|t| t.oid()) + .collect(), + ), + ), + false, + ) + .into(), + )?; back_tx.blocking_send( ( @@ -601,33 +701,43 @@ pub async fn start( ) .into(), )?; - continue; } - back_tx.blocking_send( - ( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".into(), - "XX000".into(), - "statement not found".into(), - ) - .into(), - ), - true, - ) - .into(), - )?; - } + }, // portal - b'P' => { - if let Some((_, _, prepped, result_formats)) = - portals.get(name) - { + b'P' => match portals.get(name) { + None => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".into(), + "portal not found".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + } + Some(Portal::Empty { .. }) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::NoData(NoData::new()), + false, + ) + .into(), + )?; + } + Some(Portal::Parsed { + stmt, + result_formats, + .. 
+ }) => { let mut oids = vec![]; let mut fields = vec![]; - for (i, col) in - prepped.columns().into_iter().enumerate() - { + for (i, col) in stmt.columns().into_iter().enumerate() { let col_type = match name_to_type( col.decl_type().unwrap_or("text"), @@ -666,23 +776,8 @@ pub async fn start( ) .into(), )?; - continue; } - back_tx.blocking_send( - ( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".into(), - "XX000".into(), - "portal not found".into(), - ) - .into(), - ), - true, - ) - .into(), - )?; - } + }, _ => { back_tx.blocking_send( ( @@ -711,7 +806,7 @@ pub async fn start( let stmt_name = bind.statement_name().as_deref().unwrap_or(""); - let (sql, query, _, param_types) = match prepared.get(stmt_name) { + match prepared.get(stmt_name) { None => { back_tx.blocking_send( ( @@ -729,80 +824,23 @@ pub async fn start( )?; continue; } - Some(stmt) => stmt, - }; - - let mut prepped = match conn.prepare(sql) { - Ok(prepped) => prepped, - Err(e) => { - back_tx.blocking_send( - ( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - e.to_string(), - ) - .into(), - ), - true, - ) - .into(), - )?; - continue; - } - }; - - let param_types = if bind.parameters().len() != param_types.len() { - if let Some(ParsedCmd(Cmd::Stmt(stmt))) = &query { - let params = parameter_types(&schema, stmt); - println!("computed params: {params:?}",); - params - .iter() - .map(|param| match param { - SqliteType::Null => unreachable!(), - SqliteType::Integer => Type::INT8, - SqliteType::Real => Type::FLOAT8, - SqliteType::Text => Type::TEXT, - SqliteType::Blob => Type::BYTEA, - }) - .collect() - } else { - back_tx.blocking_send( - ( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - "could not determine parameter type".into(), - ) - .into(), - ), - true, - ) - .into(), - )?; - continue 'outer; + Some(Prepared::Empty) => { + portals.insert( + portal_name, + Portal::Empty { + 
stmt_name: stmt_name.into(), + }, + ); } - } else { - param_types.clone() - }; - - println!("CMD: {query:?}"); - if bind.parameters().len() != param_types.len() { - if let Some(ParsedCmd(Cmd::Stmt(stmt))) = &query { - let params = parameter_types(&schema, stmt); - println!("computed params: {params:?}",); - } - } - - for (i, param) in bind.parameters().iter().enumerate() { - let idx = i + 1; - let b = match param { - None => { - if let Err(e) = prepped - .raw_bind_parameter(idx, rusqlite::types::Null) - { + Some(Prepared::NonEmpty { + sql, + param_types, + tag, + .. + }) => { + let mut prepped = match conn.prepare(sql) { + Ok(prepped) => prepped, + Err(e) => { back_tx.blocking_send( ( PgWireBackendMessage::ErrorResponse( @@ -817,110 +855,182 @@ pub async fn start( ) .into(), )?; - continue 'outer; + continue; } - continue; - } - Some(b) => b, - }; + }; - match param_types.get(i) { - None => { - back_tx.blocking_send( - ( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - "missing parameter type".into(), - ) - .into(), - ), - true, - ) - .into(), - )?; - continue 'outer; - } - Some(param_type) => match param_type { - &Type::BOOL => { - let value: bool = - FromSql::from_sql(param_type, b.as_ref())?; - prepped.raw_bind_parameter(idx, value)?; - } - &Type::INT2 => { - let value: i16 = - FromSql::from_sql(param_type, b.as_ref())?; - prepped.raw_bind_parameter(idx, value)?; - } - &Type::INT4 => { - let value: i32 = - FromSql::from_sql(param_type, b.as_ref())?; - prepped.raw_bind_parameter(idx, value)?; - } - &Type::INT8 => { - let value: i64 = - FromSql::from_sql(param_type, b.as_ref())?; - prepped.raw_bind_parameter(idx, value)?; - } - &Type::TEXT | &Type::VARCHAR => { - let value: &str = - FromSql::from_sql(param_type, b.as_ref())?; - prepped.raw_bind_parameter(idx, value)?; - } - &Type::FLOAT4 => { - let value: f32 = - FromSql::from_sql(param_type, b.as_ref())?; - prepped.raw_bind_parameter(idx, value)?; - } - 
&Type::FLOAT8 => { - let value: f64 = - FromSql::from_sql(param_type, b.as_ref())?; - prepped.raw_bind_parameter(idx, value)?; - } - &Type::BYTEA => { - let value: &[u8] = - FromSql::from_sql(param_type, b.as_ref())?; - prepped.raw_bind_parameter(idx, value)?; - } - t => { - warn!("unsupported type: {t:?}"); - back_tx.blocking_send( - ( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!( - "unsupported type {t} at index {i}" - ), + trace!( + "bind params count: {}, statement params count: {}", + bind.parameters().len(), + prepped.parameter_count() + ); + + for (i, param) in bind.parameters().iter().enumerate() { + let idx = i + 1; + let b = match param { + None => { + trace!("binding idx {idx} w/ NULL"); + if let Err(e) = prepped.raw_bind_parameter( + idx, + rusqlite::types::Null, + ) { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, ) + .into(), + )?; + continue 'outer; + } + continue; + } + Some(b) => b, + }; + + match param_types.get(i) { + None => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + "missing parameter type".into(), + ) + .into(), + ), + true, + ) .into(), - ), - true, - ) - .into(), - )?; - continue 'outer; + )?; + continue 'outer; + } + Some(param_type) => { + match param_type { + &Type::BOOL => { + let value: bool = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value}"); + prepped + .raw_bind_parameter(idx, value)?; + } + &Type::INT2 => { + let value: i16 = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value}"); + prepped + .raw_bind_parameter(idx, value)?; + } + &Type::INT4 => { + let value: i32 = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: 
{value}"); + prepped + .raw_bind_parameter(idx, value)?; + } + &Type::INT8 => { + let value: i64 = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value}"); + prepped + .raw_bind_parameter(idx, value)?; + } + &Type::TEXT | &Type::VARCHAR => { + let value: &str = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value}"); + prepped + .raw_bind_parameter(idx, value)?; + } + &Type::FLOAT4 => { + let value: f32 = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value}"); + prepped + .raw_bind_parameter(idx, value)?; + } + &Type::FLOAT8 => { + let value: f64 = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value}"); + prepped + .raw_bind_parameter(idx, value)?; + } + &Type::BYTEA => { + let value: &[u8] = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value:?}"); + prepped + .raw_bind_parameter(idx, value)?; + } + t => { + warn!("unsupported type: {t:?}"); + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + format!( + "unsupported type {t} at index {i}" + ), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + } + } } - }, + } + + debug!("EXPANDED SQL: {:?}", prepped.expanded_sql()); + + portals.insert( + portal_name, + Portal::Parsed { + stmt_name: stmt_name.into(), + stmt: prepped, + result_formats: bind + .result_column_format_codes() + .iter() + .copied() + .map(FieldFormat::from) + .collect(), + tag: *tag, + }, + ); } } - portals.insert( - portal_name, - ( - stmt_name.into(), - query.clone(), - prepped, - bind.result_column_format_codes() - .iter() - .copied() - .map(FieldFormat::from) - .collect(), - ), - ); - back_tx.blocking_send( ( PgWireBackendMessage::BindComplete(BindComplete::new()), @@ -947,42 +1057,48 @@ pub async fn start( } 
PgWireFrontendMessage::Execute(execute) => { let name = execute.name().as_deref().unwrap_or(""); - let (parsed_cmd, prepped, result_formats) = - match portals.get_mut(name) { - Some((_, Some(parsed_cmd), prepped, result_formats)) => { - (parsed_cmd, prepped, result_formats) - } - Some((_, None, _, _)) => { - back_tx.blocking_send( - ( - PgWireBackendMessage::EmptyQueryResponse( - EmptyQueryResponse::new(), - ), - false, - ) - .into(), - )?; - continue; - } - None => { - back_tx.blocking_send( - ( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".into(), - "XX000".into(), - "portal not found".into(), - ) - .into(), - ), - true, - ) + let (prepped, result_formats, tag) = match portals.get_mut(name) { + Some(Portal::Empty { .. }) => { + trace!("empty portal"); + back_tx.blocking_send( + ( + PgWireBackendMessage::EmptyQueryResponse( + EmptyQueryResponse::new(), + ), + false, + ) + .into(), + )?; + continue; + } + Some(Portal::Parsed { + stmt, + result_formats, + tag, + .. + }) => (stmt, result_formats, tag), + None => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".into(), + "portal not found".into(), + ) .into(), - )?; - continue; - } - }; + ), + true, + ) + .into(), + )?; + continue; + } + }; + trace!("non-empty portal!"); + + // TODO: maybe we don't need to recompute this... let mut fields = vec![]; for (i, col) in prepped.columns().into_iter().enumerate() { let col_type = @@ -1010,6 +1126,8 @@ pub async fn start( )); } + trace!("fields: {fields:?}"); + let schema = Arc::new(fields); let mut rows = prepped.raw_query(); @@ -1023,8 +1141,11 @@ pub async fn start( }; let mut count = 0; + trace!("starting loop"); + loop { if count >= max_rows { + trace!("attained max rows"); // forget the Rows iterator here so as to not reset the statement! 
std::mem::forget(rows); back_tx.blocking_send( @@ -1039,9 +1160,12 @@ pub async fn start( continue 'outer; } let row = match rows.next() { - Ok(Some(row)) => row, + Ok(Some(row)) => { + trace!("got a row: {row:?}"); + row + } Ok(None) => { - println!("done w/ rows"); + trace!("done w/ rows"); break; } Err(e) => { @@ -1115,12 +1239,14 @@ pub async fn start( } } - let tag = if parsed_cmd.returns_num_rows() { - parsed_cmd.tag(Some(count)) - } else if parsed_cmd.returns_rows_affected() { - parsed_cmd.tag(Some(conn.changes() as usize)) + trace!("done w/ rows, computing tag: {tag:?}"); + + let tag = if tag.returns_num_rows() { + tag.tag(Some(count)) + } else if tag.returns_rows_affected() { + tag.tag(Some(conn.changes() as usize)) } else { - parsed_cmd.tag(None) + tag.tag(None) }; // done! @@ -1197,13 +1323,13 @@ pub async fn start( )?; continue; } - println!("started IMPLICIT tx"); + trace!("started IMPLICIT tx"); open_tx = Some(OpenTx::Implicit); } // close the current implement tx first if matches!(open_tx, Some(OpenTx::Implicit)) && cmd.is_begin() { - println!("committing IMPLICIT tx"); + trace!("committing IMPLICIT tx"); open_tx = None; if let Err(e) = handle_commit(&agent, &conn, "COMMIT") { @@ -1224,7 +1350,7 @@ pub async fn start( continue 'outer; } - println!("committed IMPLICIT tx"); + trace!("committed IMPLICIT tx"); } let count = if cmd.is_commit() { @@ -1379,12 +1505,14 @@ pub async fn start( Some(count) }; - let tag = if cmd.returns_num_rows() { - cmd.tag(count) - } else if cmd.returns_rows_affected() { - cmd.tag(Some(conn.changes() as usize)) + let tag = cmd.tag(); + + let tag = if tag.returns_num_rows() { + tag.tag(count) + } else if tag.returns_rows_affected() { + tag.tag(Some(conn.changes() as usize)) } else { - cmd.tag(None) + tag.tag(None) }; back_tx.blocking_send( @@ -1393,11 +1521,11 @@ pub async fn start( )?; if cmd.is_begin() { - println!("setting EXPLICIT tx"); + trace!("setting EXPLICIT tx"); // explicit tx open_tx = Some(OpenTx::Explicit) } 
else if cmd.is_rollback() || cmd.is_commit() { - println!("clearing current open tx"); + trace!("clearing current open tx"); // if this was a rollback, remove the current open tx open_tx = None; } @@ -1405,7 +1533,7 @@ pub async fn start( // automatically commit an implicit tx if matches!(open_tx, Some(OpenTx::Implicit)) { - println!("committing IMPLICIT tx"); + trace!("committing IMPLICIT tx"); open_tx = None; if let Err(e) = handle_commit(&agent, &conn, "COMMIT") { @@ -1425,7 +1553,7 @@ pub async fn start( )?; continue; } - println!("committed IMPLICIT tx"); + trace!("committed IMPLICIT tx"); } let ready_status = if open_tx.is_some() { @@ -1466,9 +1594,7 @@ pub async fn start( // statement b'S' => { if prepared.remove(name).is_some() { - portals.retain(|_, (stmt_name, _, _, _)| { - stmt_name.as_str() != name - }); + portals.retain(|_, portal| portal.stmt_name() != name); } back_tx.blocking_send( ( @@ -1830,16 +1956,16 @@ fn extract_param( SqliteName::Name(_) => {} SqliteName::Qualified(tbl_name, col_name) | SqliteName::DoublyQualified(_, tbl_name, col_name) => { - println!("looking tbl {} for col {}", tbl_name.0, col_name.0); + trace!("looking tbl {} for col {}", tbl_name.0, col_name.0); if let Some(table) = tables.get(&tbl_name.0) { - println!("found table! {}", table.name); - if let Ok(unquoted) = enquote::unquote(&col_name.0) { - println!("unquoted column as: {unquoted}"); - if let Some(col) = table.columns.get(&unquoted) { - params.push(col.sql_type); - } - } else if let Some(col) = table.columns.get(&col_name.0) { - println!("could not unquote, using original"); + trace!("found table! 
{}", table.name); + let col_name = if col_name.0.starts_with('"') { + rem_first_and_last(&col_name.0) + } else { + &col_name.0 + }; + + if let Some(col) = table.columns.get(col_name) { params.push(col.sql_type); } } @@ -1891,7 +2017,7 @@ fn extract_param( if let Some(rhs) = rhs { for expr in rhs.iter() { if let Some(name) = handle_lhs_rhs(lhs, expr) { - println!("HANDLED LHS RHS: {name:?}"); + trace!("HANDLED LHS RHS: {name:?}"); match name { // not aliased! SqliteName::Id(id) => { @@ -1983,8 +2109,15 @@ fn extract_param( } } +fn rem_first_and_last(value: &str) -> &str { + let mut chars = value.chars(); + chars.next(); + chars.next_back(); + chars.as_str() +} + fn handle_select(schema: &Schema, select: &Select, params: &mut Vec) { - match &select.body.select { + let tables = match &select.body.select { OneSelect::Select { columns, from, @@ -1993,13 +2126,16 @@ fn handle_select(schema: &Schema, select: &Select, params: &mut Vec) group_by: _, window_clause: _, } => { - if let Some(from) = from { + let tables = if let Some(from) = from { let tables = handle_from(schema, from, params); if let Some(where_clause) = where_clause { - println!("WHERE CLAUSE: {where_clause:?}"); + trace!("WHERE CLAUSE: {where_clause:?}"); extract_param(schema, where_clause, &tables, params); } - } + tables + } else { + HashMap::new() + }; for col in columns.iter() { if let ResultColumn::Expr(expr, _) = col { // TODO: check against table if we can... 
@@ -2008,6 +2144,7 @@ fn handle_select(schema: &Schema, select: &Select, params: &mut Vec) } } } + tables } OneSelect::Values(values_values) => { for values in values_values.iter() { @@ -2017,6 +2154,23 @@ fn handle_select(schema: &Schema, select: &Select, params: &mut Vec) } } } + HashMap::new() + } + }; + if let Some(limit) = &select.limit { + if is_param(&limit.expr) { + trace!("limit was a param (variable), pushing Integer type"); + params.push(SqliteType::Integer); + } else { + extract_param(schema, &limit.expr, &tables, params); + } + if let Some(offset) = &limit.offset { + if is_param(offset) { + trace!("offset was a param (variable), pushing Integer type"); + params.push(SqliteType::Integer); + } else { + extract_param(schema, offset, &tables, params); + } } } } @@ -2030,13 +2184,13 @@ fn handle_from<'a>( if let Some(select) = from.select.as_deref() { match select { SelectTable::Table(qname, maybe_alias, _) => { - let maybe_table = if let Ok(unquoted) = enquote::unquote(&qname.name.0) { - schema.tables.get(&unquoted) + let actual_tbl_name = if qname.name.0.starts_with('"') { + rem_first_and_last(&qname.name.0) } else { - schema.tables.get(&qname.name.0) + &qname.name.0 }; - if let Some(table) = maybe_table { + if let Some(table) = schema.tables.get(actual_tbl_name) { if let Some(alias) = maybe_alias { let alias = match alias { As::As(name) | As::Elided(name) => name.0.clone(), @@ -2058,13 +2212,13 @@ fn handle_from<'a>( for join in joins.iter() { match &join.table { SelectTable::Table(qname, maybe_alias, _) => { - let maybe_table = if let Ok(unquoted) = enquote::unquote(&qname.name.0) { - schema.tables.get(&unquoted) + let actual_tbl_name = if qname.name.0.starts_with('"') { + rem_first_and_last(&qname.name.0) } else { - schema.tables.get(&qname.name.0) + &qname.name.0 }; - if let Some(table) = maybe_table { + if let Some(table) = schema.tables.get(actual_tbl_name) { if let Some(alias) = maybe_alias { let alias = match alias { As::As(name) | 
As::Elided(name) => name.0.clone(), @@ -2086,88 +2240,90 @@ fn handle_from<'a>( tables } -fn parameter_types(schema: &Schema, stmt: &Stmt) -> Vec { +fn parameter_types(schema: &Schema, cmd: &Cmd) -> Vec { let mut params = vec![]; - match stmt { - Stmt::Select(select) => handle_select(schema, select, &mut params), - Stmt::Delete { - tbl_name, - where_clause: Some(where_clause), - .. - } => { - let mut tables = HashMap::new(); - if let Some(tbl) = schema.tables.get(&tbl_name.name.0) { - tables.insert(tbl_name.name.0.clone(), tbl); + if let Cmd::Stmt(stmt) = cmd { + match stmt { + Stmt::Select(select) => handle_select(schema, select, &mut params), + Stmt::Delete { + tbl_name, + where_clause: Some(where_clause), + .. + } => { + let mut tables = HashMap::new(); + if let Some(tbl) = schema.tables.get(&tbl_name.name.0) { + tables.insert(tbl_name.name.0.clone(), tbl); + } + extract_param(schema, where_clause, &tables, &mut params); } - extract_param(schema, where_clause, &tables, &mut params); - } - Stmt::Insert { - tbl_name, - columns, - body, - .. - } => { - println!("GOT AN INSERT TO {tbl_name:?} on columns: {columns:?} w/ body: {body:?}"); - if let Some(table) = schema.tables.get(&tbl_name.name.0) { - match body { - InsertBody::Select(select, _) => { - if let OneSelect::Values(values_values) = &select.body.select { - for values in values_values.iter() { - for (i, expr) in values.iter().enumerate() { - if is_param(expr) { - // specified columns - let col = if let Some(columns) = columns { - columns - .get(i) - .and_then(|name| table.columns.get(&name.0)) - } else { - table.columns.get_index(i).map(|(_name, col)| col) - }; - if let Some(col) = col { - params.push(col.sql_type); + Stmt::Insert { + tbl_name, + columns, + body, + .. 
+ } => { + trace!("GOT AN INSERT TO {tbl_name:?} on columns: {columns:?} w/ body: {body:?}"); + if let Some(table) = schema.tables.get(&tbl_name.name.0) { + match body { + InsertBody::Select(select, _) => { + if let OneSelect::Values(values_values) = &select.body.select { + for values in values_values.iter() { + for (i, expr) in values.iter().enumerate() { + if is_param(expr) { + // specified columns + let col = if let Some(columns) = columns { + columns + .get(i) + .and_then(|name| table.columns.get(&name.0)) + } else { + table.columns.get_index(i).map(|(_name, col)| col) + }; + if let Some(col) = col { + params.push(col.sql_type); + } } } } + } else { + handle_select(schema, select, &mut params) } - } else { - handle_select(schema, select, &mut params) } - } - InsertBody::DefaultValues => { - // nothing to do! + InsertBody::DefaultValues => { + // nothing to do! + } } } } - } - Stmt::Update { - with: _, - or_conflict: _, - tbl_name, - indexed: _, - sets: _, - from, - where_clause, - returning: _, - order_by: _, - limit: _, - } => { - let mut tables = if let Some(from) = from { - handle_from(schema, from, &mut params) - } else { - Default::default() - }; - if let Some(tbl) = schema.tables.get(&tbl_name.name.0) { - tables.insert(tbl_name.name.0.clone(), tbl); + Stmt::Update { + with: _, + or_conflict: _, + tbl_name, + indexed: _, + sets: _, + from, + where_clause, + returning: _, + order_by: _, + limit: _, + } => { + let mut tables = if let Some(from) = from { + handle_from(schema, from, &mut params) + } else { + Default::default() + }; + if let Some(tbl) = schema.tables.get(&tbl_name.name.0) { + tables.insert(tbl_name.name.0.clone(), tbl); + } + if let Some(where_clause) = where_clause { + trace!("WHERE CLAUSE: {where_clause:?}"); + extract_param(schema, where_clause, &tables, &mut params); + } } - if let Some(where_clause) = where_clause { - println!("WHERE CLAUSE: {where_clause:?}"); - extract_param(schema, where_clause, &tables, &mut params); + _ => { + // do 
nothing, there can't be bound params here! } } - _ => { - // do nothing, there can't be bound params here! - } } params @@ -2282,7 +2438,7 @@ mod tests { println!("ROW: {row:?}"); let row = client - .query_one("SELECT t.id, t.text, t2.text as t2text FROM tests t LEFT JOIN tests2 t2 WHERE t.id = ?", &[&2i64]) + .query_one("SELECT t.id, t.text, t2.text as t2text FROM tests t LEFT JOIN tests2 t2 WHERE t.id = ? LIMIT ?", &[&2i64, &1i64]) .await?; println!("ROW: {row:?}"); From 7919470b45771865b8d06399631ebe498a5df783 Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Fri, 27 Oct 2023 11:55:08 -0400 Subject: [PATCH 11/12] update changelog and docs --- CHANGELOG.md | 2 ++ doc/SUMMARY.md | 1 + doc/api/pg.md | 15 +++++++++++++++ 3 files changed, 18 insertions(+) create mode 100644 doc/api/pg.md diff --git a/CHANGELOG.md b/CHANGELOG.md index cfa84f0d..fc41e49f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## Unreleased +- Implement a PostgreSQL wire protocol (v3) compatible API ([#83](../../pull/83)) +- Accept _all_ JSON types for SQLite params input ([#82](../../pull/82)) - Parallel synchronization w/ many deadlock and bug fixes ([#78](../../pull/78)) - Upgraded to cr-sqlite 0.16.0 (unreleased) ([#75](../../pull/75)) - Rewrite compaction logic to be more correct and efficient ([#74](../../pull/74)) diff --git a/doc/SUMMARY.md b/doc/SUMMARY.md index 00c2fb82..cea96337 100644 --- a/doc/SUMMARY.md +++ b/doc/SUMMARY.md @@ -22,6 +22,7 @@ - [POST /v1/transactions](api/transactions.md) - [POST /v1/queries](api/queries.md) - [POST /v1/subscriptions](api/subscriptions.md) + - [PostgreSQL Wire Protocol](api/pg.md) - [Command-line Interface](cli/README.md) - [agent](cli/agent.md) - [backup](cli/backup.md) diff --git a/doc/api/pg.md b/doc/api/pg.md new file mode 100644 index 00000000..6646c657 --- /dev/null +++ b/doc/api/pg.md @@ -0,0 +1,15 @@ +# PostgreSQL Wire Protocol v3 API (experimental) + +It's possible to configure a PostgreSQL wire protocol compatible API 
listener via the `api.pg.addr` setting. + +This is currently experimental, but it does work for most queries that are SQLite-flavored SQL. + +## What works + +- Read and write queries, parsable as SQLite-flavored SQL +- Most parameter bindings, but not all (work in progress) + +## Does not work + +- Any PostgreSQL-only SQL syntax +- Some placement of variable parameters (when binding) \ No newline at end of file From 448138f89cc4e9d711ce7cd84a06d88edee04674 Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Fri, 27 Oct 2023 12:14:48 -0400 Subject: [PATCH 12/12] remove unused crates in corro-pg --- Cargo.lock | 12 -- crates/corro-pg/Cargo.toml | 3 - crates/corro-pg/src/sql_state.rs | 332 ------------------------------- 3 files changed, 347 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e2bc472c..b2d73dde 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -830,16 +830,13 @@ dependencies = [ "compact_str 0.7.0", "corro-tests", "corro-types", - "enquote", "fallible-iterator", "futures", "pgwire", - "phf", "postgres-types", "rusqlite", "spawn", "sqlite3-parser", - "sqlparser", "tempfile", "thiserror", "time", @@ -3639,15 +3636,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "sqlparser" -version = "0.38.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0272b7bb0a225320170c99901b4b5fb3a4384e255a7f2cc228f61e2ba3893e75" -dependencies = [ - "log", -] - [[package]] name = "static_assertions" version = "1.1.0" diff --git a/crates/corro-pg/Cargo.toml b/crates/corro-pg/Cargo.toml index 6b7b1418..72a73b8c 100644 --- a/crates/corro-pg/Cargo.toml +++ b/crates/corro-pg/Cargo.toml @@ -10,12 +10,10 @@ corro-types = { path = "../corro-types" } fallible-iterator = { workspace = true } futures = { workspace = true } pgwire = { version = "0.16.1" } -phf = "*" postgres-types = { version = "0.2", features = ["with-time-0_3"] } rusqlite = { workspace = true } spawn = { path = "../spawn" } sqlite3-parser = { workspace = true } -sqlparser = { 
version = "0.38" } tempfile = { workspace = true } thiserror = { workspace = true } time = { workspace = true } @@ -23,7 +21,6 @@ tokio = { workspace = true } tokio-util = { workspace = true } tracing = { workspace = true } tripwire = { path = "../tripwire" } -enquote = { workspace = true } [dev-dependencies] corro-tests = { path = "../corro-tests" } diff --git a/crates/corro-pg/src/sql_state.rs b/crates/corro-pg/src/sql_state.rs index d8300247..4ddf16ad 100644 --- a/crates/corro-pg/src/sql_state.rs +++ b/crates/corro-pg/src/sql_state.rs @@ -3,14 +3,6 @@ pub struct SqlState(Inner); impl SqlState { - /// Creates a `SqlState` from its error code. - pub fn from_code(s: &str) -> SqlState { - match SQLSTATE_MAP.get(s) { - Some(state) => state.clone(), - None => SqlState(Inner::Other(s.into())), - } - } - /// Returns the error code corresponding to the `SqlState`. pub fn code(&self) -> &str { match &self.0 { @@ -274,7 +266,6 @@ impl SqlState { Inner::EXX000 => "XX000", Inner::EXX001 => "XX001", Inner::EXX002 => "XX002", - Inner::Other(code) => code, } } @@ -1342,327 +1333,4 @@ enum Inner { EXX000, EXX001, EXX002, - Other(Box), } - -#[rustfmt::skip] -static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = -::phf::Map { - key: 12913932095322966823, - disps: &[ - (0, 24), - (0, 12), - (0, 74), - (0, 109), - (0, 11), - (0, 9), - (0, 0), - (4, 38), - (3, 155), - (0, 6), - (1, 242), - (0, 66), - (0, 53), - (5, 180), - (3, 221), - (7, 230), - (0, 125), - (1, 46), - (0, 11), - (1, 2), - (0, 5), - (0, 13), - (0, 171), - (0, 15), - (0, 4), - (0, 22), - (1, 85), - (0, 75), - (2, 0), - (1, 25), - (7, 47), - (0, 45), - (0, 35), - (0, 7), - (7, 124), - (0, 0), - (14, 104), - (1, 183), - (61, 50), - (3, 76), - (0, 12), - (0, 7), - (4, 189), - (0, 1), - (64, 102), - (0, 0), - (16, 192), - (24, 19), - (0, 5), - (0, 87), - (0, 89), - (0, 14), - ], - entries: &[ - ("2F000", SqlState::SQL_ROUTINE_EXCEPTION), - ("01008", SqlState::WARNING_IMPLICIT_ZERO_BIT_PADDING), - ("42501", 
SqlState::INSUFFICIENT_PRIVILEGE), - ("22000", SqlState::DATA_EXCEPTION), - ("0100C", SqlState::WARNING_DYNAMIC_RESULT_SETS_RETURNED), - ("2200N", SqlState::INVALID_XML_CONTENT), - ("40001", SqlState::T_R_SERIALIZATION_FAILURE), - ("28P01", SqlState::INVALID_PASSWORD), - ("38000", SqlState::EXTERNAL_ROUTINE_EXCEPTION), - ("25006", SqlState::READ_ONLY_SQL_TRANSACTION), - ("2203D", SqlState::TOO_MANY_JSON_ARRAY_ELEMENTS), - ("42P09", SqlState::AMBIGUOUS_ALIAS), - ("F0000", SqlState::CONFIG_FILE_ERROR), - ("42P18", SqlState::INDETERMINATE_DATATYPE), - ("40002", SqlState::T_R_INTEGRITY_CONSTRAINT_VIOLATION), - ("22009", SqlState::INVALID_TIME_ZONE_DISPLACEMENT_VALUE), - ("42P08", SqlState::AMBIGUOUS_PARAMETER), - ("08000", SqlState::CONNECTION_EXCEPTION), - ("25P01", SqlState::NO_ACTIVE_SQL_TRANSACTION), - ("22024", SqlState::UNTERMINATED_C_STRING), - ("55000", SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE), - ("25001", SqlState::ACTIVE_SQL_TRANSACTION), - ("03000", SqlState::SQL_STATEMENT_NOT_YET_COMPLETE), - ("42710", SqlState::DUPLICATE_OBJECT), - ("2D000", SqlState::INVALID_TRANSACTION_TERMINATION), - ("2200G", SqlState::MOST_SPECIFIC_TYPE_MISMATCH), - ("22022", SqlState::INDICATOR_OVERFLOW), - ("55006", SqlState::OBJECT_IN_USE), - ("53200", SqlState::OUT_OF_MEMORY), - ("22012", SqlState::DIVISION_BY_ZERO), - ("P0002", SqlState::NO_DATA_FOUND), - ("XX001", SqlState::DATA_CORRUPTED), - ("22P05", SqlState::UNTRANSLATABLE_CHARACTER), - ("40003", SqlState::T_R_STATEMENT_COMPLETION_UNKNOWN), - ("22021", SqlState::CHARACTER_NOT_IN_REPERTOIRE), - ("25000", SqlState::INVALID_TRANSACTION_STATE), - ("42P15", SqlState::INVALID_SCHEMA_DEFINITION), - ("0B000", SqlState::INVALID_TRANSACTION_INITIATION), - ("22004", SqlState::NULL_VALUE_NOT_ALLOWED), - ("42804", SqlState::DATATYPE_MISMATCH), - ("42803", SqlState::GROUPING_ERROR), - ("02001", SqlState::NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED), - ("25002", SqlState::BRANCH_TRANSACTION_ALREADY_ACTIVE), - ("28000", 
SqlState::INVALID_AUTHORIZATION_SPECIFICATION), - ("HV009", SqlState::FDW_INVALID_USE_OF_NULL_POINTER), - ("22P01", SqlState::FLOATING_POINT_EXCEPTION), - ("2B000", SqlState::DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST), - ("42723", SqlState::DUPLICATE_FUNCTION), - ("21000", SqlState::CARDINALITY_VIOLATION), - ("0Z002", SqlState::STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER), - ("23505", SqlState::UNIQUE_VIOLATION), - ("HV00J", SqlState::FDW_OPTION_NAME_NOT_FOUND), - ("23P01", SqlState::EXCLUSION_VIOLATION), - ("39P03", SqlState::E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED), - ("42P10", SqlState::INVALID_COLUMN_REFERENCE), - ("2202H", SqlState::INVALID_TABLESAMPLE_ARGUMENT), - ("55P04", SqlState::UNSAFE_NEW_ENUM_VALUE_USAGE), - ("P0000", SqlState::PLPGSQL_ERROR), - ("2F005", SqlState::S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT), - ("HV00M", SqlState::FDW_UNABLE_TO_CREATE_REPLY), - ("0A000", SqlState::FEATURE_NOT_SUPPORTED), - ("24000", SqlState::INVALID_CURSOR_STATE), - ("25008", SqlState::HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL), - ("01003", SqlState::WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION), - ("42712", SqlState::DUPLICATE_ALIAS), - ("HV014", SqlState::FDW_TOO_MANY_HANDLES), - ("58030", SqlState::IO_ERROR), - ("2201W", SqlState::INVALID_ROW_COUNT_IN_LIMIT_CLAUSE), - ("22033", SqlState::INVALID_SQL_JSON_SUBSCRIPT), - ("2BP01", SqlState::DEPENDENT_OBJECTS_STILL_EXIST), - ("HV005", SqlState::FDW_COLUMN_NAME_NOT_FOUND), - ("25004", SqlState::INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION), - ("54000", SqlState::PROGRAM_LIMIT_EXCEEDED), - ("20000", SqlState::CASE_NOT_FOUND), - ("2203G", SqlState::SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE), - ("22038", SqlState::SINGLETON_SQL_JSON_ITEM_REQUIRED), - ("22007", SqlState::INVALID_DATETIME_FORMAT), - ("08004", SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION), - ("2200H", SqlState::SEQUENCE_GENERATOR_LIMIT_EXCEEDED), - ("HV00D", SqlState::FDW_INVALID_OPTION_NAME), - ("P0004", 
SqlState::ASSERT_FAILURE), - ("22018", SqlState::INVALID_CHARACTER_VALUE_FOR_CAST), - ("0L000", SqlState::INVALID_GRANTOR), - ("22P04", SqlState::BAD_COPY_FILE_FORMAT), - ("22031", SqlState::INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION), - ("01P01", SqlState::WARNING_DEPRECATED_FEATURE), - ("0LP01", SqlState::INVALID_GRANT_OPERATION), - ("58P02", SqlState::DUPLICATE_FILE), - ("26000", SqlState::INVALID_SQL_STATEMENT_NAME), - ("54001", SqlState::STATEMENT_TOO_COMPLEX), - ("22010", SqlState::INVALID_INDICATOR_PARAMETER_VALUE), - ("HV00C", SqlState::FDW_INVALID_OPTION_INDEX), - ("22008", SqlState::DATETIME_FIELD_OVERFLOW), - ("42P06", SqlState::DUPLICATE_SCHEMA), - ("25007", SqlState::SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED), - ("42P20", SqlState::WINDOWING_ERROR), - ("HV091", SqlState::FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER), - ("HV021", SqlState::FDW_INCONSISTENT_DESCRIPTOR_INFORMATION), - ("42702", SqlState::AMBIGUOUS_COLUMN), - ("02000", SqlState::NO_DATA), - ("54011", SqlState::TOO_MANY_COLUMNS), - ("HV004", SqlState::FDW_INVALID_DATA_TYPE), - ("01006", SqlState::WARNING_PRIVILEGE_NOT_REVOKED), - ("42701", SqlState::DUPLICATE_COLUMN), - ("08P01", SqlState::PROTOCOL_VIOLATION), - ("42622", SqlState::NAME_TOO_LONG), - ("P0003", SqlState::TOO_MANY_ROWS), - ("22003", SqlState::NUMERIC_VALUE_OUT_OF_RANGE), - ("42P03", SqlState::DUPLICATE_CURSOR), - ("23001", SqlState::RESTRICT_VIOLATION), - ("57000", SqlState::OPERATOR_INTERVENTION), - ("22027", SqlState::TRIM_ERROR), - ("42P12", SqlState::INVALID_DATABASE_DEFINITION), - ("3B000", SqlState::SAVEPOINT_EXCEPTION), - ("2201B", SqlState::INVALID_REGULAR_EXPRESSION), - ("22030", SqlState::DUPLICATE_JSON_OBJECT_KEY_VALUE), - ("2F004", SqlState::S_R_E_READING_SQL_DATA_NOT_PERMITTED), - ("428C9", SqlState::GENERATED_ALWAYS), - ("2200S", SqlState::INVALID_XML_COMMENT), - ("22039", SqlState::SQL_JSON_ARRAY_NOT_FOUND), - ("42809", SqlState::WRONG_OBJECT_TYPE), - ("2201X", 
SqlState::INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE), - ("39001", SqlState::E_R_I_E_INVALID_SQLSTATE_RETURNED), - ("25P02", SqlState::IN_FAILED_SQL_TRANSACTION), - ("0P000", SqlState::INVALID_ROLE_SPECIFICATION), - ("HV00N", SqlState::FDW_UNABLE_TO_ESTABLISH_CONNECTION), - ("53100", SqlState::DISK_FULL), - ("42601", SqlState::SYNTAX_ERROR), - ("23000", SqlState::INTEGRITY_CONSTRAINT_VIOLATION), - ("HV006", SqlState::FDW_INVALID_DATA_TYPE_DESCRIPTORS), - ("HV00B", SqlState::FDW_INVALID_HANDLE), - ("HV00Q", SqlState::FDW_SCHEMA_NOT_FOUND), - ("01000", SqlState::WARNING), - ("42883", SqlState::UNDEFINED_FUNCTION), - ("57P01", SqlState::ADMIN_SHUTDOWN), - ("22037", SqlState::NON_UNIQUE_KEYS_IN_A_JSON_OBJECT), - ("00000", SqlState::SUCCESSFUL_COMPLETION), - ("55P03", SqlState::LOCK_NOT_AVAILABLE), - ("42P01", SqlState::UNDEFINED_TABLE), - ("42830", SqlState::INVALID_FOREIGN_KEY), - ("22005", SqlState::ERROR_IN_ASSIGNMENT), - ("22025", SqlState::INVALID_ESCAPE_SEQUENCE), - ("XX002", SqlState::INDEX_CORRUPTED), - ("42P16", SqlState::INVALID_TABLE_DEFINITION), - ("55P02", SqlState::CANT_CHANGE_RUNTIME_PARAM), - ("22019", SqlState::INVALID_ESCAPE_CHARACTER), - ("P0001", SqlState::RAISE_EXCEPTION), - ("72000", SqlState::SNAPSHOT_TOO_OLD), - ("42P11", SqlState::INVALID_CURSOR_DEFINITION), - ("40P01", SqlState::T_R_DEADLOCK_DETECTED), - ("57P02", SqlState::CRASH_SHUTDOWN), - ("HV00A", SqlState::FDW_INVALID_STRING_FORMAT), - ("2F002", SqlState::S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), - ("23503", SqlState::FOREIGN_KEY_VIOLATION), - ("40000", SqlState::TRANSACTION_ROLLBACK), - ("22032", SqlState::INVALID_JSON_TEXT), - ("2202E", SqlState::ARRAY_ELEMENT_ERROR), - ("42P19", SqlState::INVALID_RECURSION), - ("42611", SqlState::INVALID_COLUMN_DEFINITION), - ("42P13", SqlState::INVALID_FUNCTION_DEFINITION), - ("25003", SqlState::INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION), - ("39P02", SqlState::E_R_I_E_SRF_PROTOCOL_VIOLATED), - ("XX000", SqlState::INTERNAL_ERROR), - ("08006", 
SqlState::CONNECTION_FAILURE), - ("57P04", SqlState::DATABASE_DROPPED), - ("42P07", SqlState::DUPLICATE_TABLE), - ("22P03", SqlState::INVALID_BINARY_REPRESENTATION), - ("22035", SqlState::NO_SQL_JSON_ITEM), - ("42P14", SqlState::INVALID_PSTATEMENT_DEFINITION), - ("01007", SqlState::WARNING_PRIVILEGE_NOT_GRANTED), - ("38004", SqlState::E_R_E_READING_SQL_DATA_NOT_PERMITTED), - ("42P21", SqlState::COLLATION_MISMATCH), - ("0Z000", SqlState::DIAGNOSTICS_EXCEPTION), - ("HV001", SqlState::FDW_OUT_OF_MEMORY), - ("0F000", SqlState::LOCATOR_EXCEPTION), - ("22013", SqlState::INVALID_PRECEDING_OR_FOLLOWING_SIZE), - ("2201E", SqlState::INVALID_ARGUMENT_FOR_LOG), - ("22011", SqlState::SUBSTRING_ERROR), - ("42602", SqlState::INVALID_NAME), - ("01004", SqlState::WARNING_STRING_DATA_RIGHT_TRUNCATION), - ("42P02", SqlState::UNDEFINED_PARAMETER), - ("2203C", SqlState::SQL_JSON_OBJECT_NOT_FOUND), - ("HV002", SqlState::FDW_DYNAMIC_PARAMETER_VALUE_NEEDED), - ("0F001", SqlState::L_E_INVALID_SPECIFICATION), - ("58P01", SqlState::UNDEFINED_FILE), - ("38001", SqlState::E_R_E_CONTAINING_SQL_NOT_PERMITTED), - ("42703", SqlState::UNDEFINED_COLUMN), - ("57P05", SqlState::IDLE_SESSION_TIMEOUT), - ("57P03", SqlState::CANNOT_CONNECT_NOW), - ("HV007", SqlState::FDW_INVALID_COLUMN_NAME), - ("22014", SqlState::INVALID_ARGUMENT_FOR_NTILE), - ("22P06", SqlState::NONSTANDARD_USE_OF_ESCAPE_CHARACTER), - ("2203F", SqlState::SQL_JSON_SCALAR_REQUIRED), - ("2200F", SqlState::ZERO_LENGTH_CHARACTER_STRING), - ("09000", SqlState::TRIGGERED_ACTION_EXCEPTION), - ("2201F", SqlState::INVALID_ARGUMENT_FOR_POWER_FUNCTION), - ("08003", SqlState::CONNECTION_DOES_NOT_EXIST), - ("38002", SqlState::E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), - ("F0001", SqlState::LOCK_FILE_EXISTS), - ("42P22", SqlState::INDETERMINATE_COLLATION), - ("2200C", SqlState::INVALID_USE_OF_ESCAPE_CHARACTER), - ("2203E", SqlState::TOO_MANY_JSON_OBJECT_MEMBERS), - ("23514", SqlState::CHECK_VIOLATION), - ("22P02", 
SqlState::INVALID_TEXT_REPRESENTATION), - ("54023", SqlState::TOO_MANY_ARGUMENTS), - ("2200T", SqlState::INVALID_XML_PROCESSING_INSTRUCTION), - ("22016", SqlState::INVALID_ARGUMENT_FOR_NTH_VALUE), - ("25P03", SqlState::IDLE_IN_TRANSACTION_SESSION_TIMEOUT), - ("3B001", SqlState::S_E_INVALID_SPECIFICATION), - ("08001", SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION), - ("22036", SqlState::NON_NUMERIC_SQL_JSON_ITEM), - ("3F000", SqlState::INVALID_SCHEMA_NAME), - ("39P01", SqlState::E_R_I_E_TRIGGER_PROTOCOL_VIOLATED), - ("22026", SqlState::STRING_DATA_LENGTH_MISMATCH), - ("42P17", SqlState::INVALID_OBJECT_DEFINITION), - ("22034", SqlState::MORE_THAN_ONE_SQL_JSON_ITEM), - ("HV000", SqlState::FDW_ERROR), - ("2200B", SqlState::ESCAPE_CHARACTER_CONFLICT), - ("HV008", SqlState::FDW_INVALID_COLUMN_NUMBER), - ("34000", SqlState::INVALID_CURSOR_NAME), - ("2201G", SqlState::INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION), - ("44000", SqlState::WITH_CHECK_OPTION_VIOLATION), - ("HV010", SqlState::FDW_FUNCTION_SEQUENCE_ERROR), - ("39004", SqlState::E_R_I_E_NULL_VALUE_NOT_ALLOWED), - ("22001", SqlState::STRING_DATA_RIGHT_TRUNCATION), - ("3D000", SqlState::INVALID_CATALOG_NAME), - ("25005", SqlState::NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION), - ("2200L", SqlState::NOT_AN_XML_DOCUMENT), - ("27000", SqlState::TRIGGERED_DATA_CHANGE_VIOLATION), - ("HV090", SqlState::FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH), - ("42939", SqlState::RESERVED_NAME), - ("58000", SqlState::SYSTEM_ERROR), - ("2200M", SqlState::INVALID_XML_DOCUMENT), - ("HV00L", SqlState::FDW_UNABLE_TO_CREATE_EXECUTION), - ("57014", SqlState::QUERY_CANCELED), - ("23502", SqlState::NOT_NULL_VIOLATION), - ("22002", SqlState::NULL_VALUE_NO_INDICATOR_PARAMETER), - ("HV00R", SqlState::FDW_TABLE_NOT_FOUND), - ("HV00P", SqlState::FDW_NO_SCHEMAS), - ("38003", SqlState::E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), - ("39000", SqlState::EXTERNAL_ROUTINE_INVOCATION_EXCEPTION), - ("22015", SqlState::INTERVAL_FIELD_OVERFLOW), - 
("HV00K", SqlState::FDW_REPLY_HANDLE), - ("HV024", SqlState::FDW_INVALID_ATTRIBUTE_VALUE), - ("2200D", SqlState::INVALID_ESCAPE_OCTET), - ("08007", SqlState::TRANSACTION_RESOLUTION_UNKNOWN), - ("2F003", SqlState::S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), - ("42725", SqlState::AMBIGUOUS_FUNCTION), - ("2203A", SqlState::SQL_JSON_MEMBER_NOT_FOUND), - ("42846", SqlState::CANNOT_COERCE), - ("42P04", SqlState::DUPLICATE_DATABASE), - ("42000", SqlState::SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION), - ("2203B", SqlState::SQL_JSON_NUMBER_NOT_FOUND), - ("42P05", SqlState::DUPLICATE_PSTATEMENT), - ("53300", SqlState::TOO_MANY_CONNECTIONS), - ("53400", SqlState::CONFIGURATION_LIMIT_EXCEEDED), - ("42704", SqlState::UNDEFINED_OBJECT), - ("2202G", SqlState::INVALID_TABLESAMPLE_REPEAT), - ("22023", SqlState::INVALID_PARAMETER_VALUE), - ("53000", SqlState::INSUFFICIENT_RESOURCES), - ], -};