diff --git a/Cargo.toml b/Cargo.toml
index 3bc30e1..7427a2f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -27,10 +27,9 @@ documentation = "https://docs.rs/seamdb"
 
 [dependencies]
 anyhow = { version = "1.0", features = ["backtrace"] }
-async-trait = "0.1.72"
+async-trait = "0.1.81"
 bytesize = "1.2.0"
 compact_str = "0.7.1"
-derivative = "2.2.0"
 hashlink = "0.8.3"
 rdkafka = "0.34.0"
 tokio = {version = "1.30.0", features = ["full"]}
@@ -55,13 +54,25 @@ bytemuck = { version = "1.13.1", features = ["derive"] }
 futures = "0.3.28"
 thiserror = "1.0.48"
 tokio-stream = "0.1.15"
-asyncs = "0.3.0"
+datafusion = "43.0.0"
+pgwire = "0.27.0"
+derive-where = "1.2.7"
+asyncs = { version = "0.3.0", features = ["tokio"] }
 async-io = "2.3.4"
+bytes = "1.7.2"
+pg_query = "5.1.1"
+lazy_static = "1.5.0"
+lazy-init = "0.5.1"
+enum_dispatch = "0.3.13"
+jiff = "0.1.15"
+clap = { version = "4.5.23", features = ["derive"] }
+tracing-appender = "0.2.3"
+tracing-subscriber = { version = "0.3.19", features = ["tracing-log", "env-filter", "std"] }
 
 [dev-dependencies]
 assertor = "0.0.2"
 asyncs = { version = "0.3.0", features = ["test", "tokio"] }
-env_logger = "0.10.0"
+env_logger = "0.11.5"
 serial_test = "2.0.0"
 speculoos = "0.11.0"
 test-case = "3.1.0"
@@ -69,6 +80,12 @@ test-log = "0.2.12"
 testcontainers = "0.14.0"
 tracing-test = "0.2.4"
 
+[profile.dev]
+lto = "thin"
+
+[profile.release]
+lto = "thin"
+
 [workspace]
 members = ["src/protos/build"]
diff --git a/Makefile b/Makefile
index 692e93d..9755350 100644
--- a/Makefile
+++ b/Makefile
@@ -35,7 +35,11 @@ lint:
	cargo clippy --no-deps -- -D clippy::all
 
 build:
-	cargo build --tests
+	cargo build --tests --bins
+release:
+	cargo build --tests --bins --release
 
 test:
	cargo test
+clean:
+	cargo clean
diff --git a/src/bin/seamdbd.rs b/src/bin/seamdbd.rs
new file mode 100644
index 0000000..aa34bed
--- /dev/null
+++ b/src/bin/seamdbd.rs
@@ -0,0 +1,104 @@
+// Copyright 2023 The SeamDB Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
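+
+//! The SeamDB server daemon: it joins the meta cluster (e.g. etcd), starts a
+//! tablet node, and serves PostgreSQL compatible SQL on the configured port.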
+
+use std::net::SocketAddr;
+use std::time::Duration;
+
+use anyhow::{anyhow, Result};
+use clap::Parser;
+use pgwire::tokio::process_socket;
+use seamdb::cluster::{ClusterEnv, EtcdClusterMetaDaemon, EtcdNodeRegistry, NodeId};
+use seamdb::endpoint::{Endpoint, Params, ServiceUri};
+use seamdb::log::{KafkaLogFactory, LogManager, MemoryLogFactory};
+use seamdb::protos::TableDescriptor;
+use seamdb::sql::postgres::PostgresqlHandlerFactory;
+use seamdb::tablet::{TabletClient, TabletNode};
+use tokio::net::{TcpListener, TcpStream};
+use tracing::{info, instrument};
+use tracing_subscriber::prelude::*;
+use tracing_subscriber::{fmt, EnvFilter};
+
+async fn new_log_manager(uri: ServiceUri<'_>) -> Result<LogManager> {
+    match uri.scheme() {
+        "memory" => LogManager::new(MemoryLogFactory::new(), &MemoryLogFactory::ENDPOINT, &Params::default()).await,
+        "kafka" => LogManager::new(KafkaLogFactory {}, &uri.endpoint(), uri.params()).await,
+        scheme => Err(anyhow!("unsupported log scheme: {}, supported: memory, kafka", scheme)),
+    }
+}
+
+#[instrument(skip_all, fields(addr = %addr))]
+async fn serve_connection(factory: PostgresqlHandlerFactory, stream: TcpStream, addr: SocketAddr) {
+    match process_socket(stream, None, factory).await {
+        Ok(_) => info!("connection terminated"),
+        Err(err) => info!("connection terminated: {err}"),
+    }
+}
+
+#[derive(Parser, Debug)]
+#[command(version, about, long_about = None)]
+pub struct Args {
+    /// Meta cluster uri to store cluster wide metadata, e.g. etcd://etcd-cluster/scope.
+    #[arg(long = "cluster.uri")]
+    cluster_uri: String,
+    /// Cluster name.
+    #[arg(long = "cluster.name", default_value = "seamdb")]
+    cluster_name: String,
+    /// Log cluster uri to store WAL logs, e.g. kafka://kafka-cluster.
+    #[arg(long = "log.uri")]
+    log_uri: String,
+    /// Port to serve PostgreSQL compatible SQL statements.
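+    /// Defaults to 5432, the standard PostgreSQL port.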
+    #[arg(long = "sql.postgresql.port", default_value_t = 5432)]
+    pgsql_port: u16,
+}
+
+#[tokio::main]
+async fn main() {
+    let (non_blocking, _guard) = tracing_appender::non_blocking(std::io::stdout());
+
+    tracing_subscriber::registry()
+        .with(fmt::layer().with_writer(non_blocking).with_level(true).with_file(true).with_line_number(true))
+        .with(EnvFilter::from_default_env())
+        .init();
+
+    let args = Args::parse();
+    let cluster_uri = ServiceUri::parse(&args.cluster_uri).unwrap();
+    let log_uri = ServiceUri::parse(&args.log_uri).unwrap();
+
+    let node_id = NodeId::new_random();
+    info!("Starting node {node_id}");
+
+    let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
+    let address = format!("http://{}", listener.local_addr().unwrap());
+    let endpoint = Endpoint::try_from(address.as_str()).unwrap();
+    let (nodes, lease) =
+        EtcdNodeRegistry::join(cluster_uri.clone(), node_id.clone(), Some(endpoint.to_owned())).await.unwrap();
+    let log_manager = new_log_manager(log_uri).await.unwrap();
+    let cluster_env = ClusterEnv::new(log_manager.into(), nodes).with_replicas(1);
+    let mut cluster_meta_handle =
+        EtcdClusterMetaDaemon::start(args.cluster_name, cluster_uri.clone(), cluster_env.clone()).await.unwrap();
+    let descriptor_watcher = cluster_meta_handle.watch_descriptor(None).await.unwrap();
+    let deployment_watcher = cluster_meta_handle.watch_deployment(None).await.unwrap();
+    let cluster_env = cluster_env.with_descriptor(descriptor_watcher).with_deployment(deployment_watcher.monitor());
+    let _node = TabletNode::start(node_id, listener, lease, cluster_env.clone());
+    let client = TabletClient::new(cluster_env).scope(TableDescriptor::POSTGRESQL_DIALECT_PREFIX);
+    tokio::time::sleep(Duration::from_secs(20)).await;
+
+    let factory = PostgresqlHandlerFactory::new(client);
+    let listener = TcpListener::bind(format!("0.0.0.0:{}", args.pgsql_port)).await.unwrap();
+    info!("Listening on {} ...", listener.local_addr().unwrap());
+    loop {
+        let (stream, addr) = listener.accept().await.unwrap();
+        tokio::spawn(serve_connection(factory.clone(), stream, addr));
+    }
+}
diff --git a/src/clock.rs b/src/clock.rs
index 466f943..36a8dea 100644
--- a/src/clock.rs
+++ b/src/clock.rs
@@ -16,6 +16,7 @@ use std::ops::{Add, Sub};
 use std::sync::Arc;
 use std::time::{Duration, SystemTime};
 
+use jiff::Timestamp as JiffTimestamp;
 use static_assertions::{assert_impl_all, assert_not_impl_any};
 
 pub use crate::protos::Timestamp;
@@ -88,7 +89,15 @@ impl SystemTimeClock {
 
 impl std::fmt::Display for Timestamp {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{:?}", self)
+        if let Some(sequence) = self.get_txn_sequence() {
+            return write!(f, "txn-seq-{}", sequence);
+        }
+        let ts = JiffTimestamp::new(self.seconds as i64, self.nanoseconds as i32).unwrap();
+        if self.logical == 0 {
+            write!(f, "{}", ts)
+        } else {
+            write!(f, "{}-{}", ts, self.logical)
+        }
     }
 }
 
@@ -101,6 +110,13 @@ impl Timestamp {
         self.seconds == 0 && self.nanoseconds == 0 && self.logical == 0
     }
 
+    pub const fn get_txn_sequence(&self) -> Option<u32> {
+        match self.seconds & 0x8000000000000000 != 0 {
+            true => Some(self.seconds as u32),
+            false => None,
+        }
+    }
+
     pub const fn txn_sequence(sequence: u32) -> Self {
         Self { seconds: 0x8000000000000000 + sequence as u64, nanoseconds: 0, logical: 0 }
     }
diff --git a/src/cluster/etcd.rs b/src/cluster/etcd.rs
index 62d7e8e..c2e6604 100644
--- a/src/cluster/etcd.rs
+++ b/src/cluster/etcd.rs
@@ -46,7 +46,7 @@ impl EtcdLease {
     }
 }
 
-pub(super) enum EtcdHelper {}
+pub enum EtcdHelper {}
 
 impl EtcdHelper {
     pub async fn connect(endpoint: Endpoint<'_>, params: &Params<'_>) -> Result<Client> {
@@ -129,6 +129,8 @@ pub mod tests {
         container: Container<'static, GenericImage>,
     }
 
+    unsafe impl Send for EtcdContainer {}
+
     impl EtcdContainer {
         pub fn uri(&self) -> ServiceUri<'static> {
             let cluster = format!("etcd://127.0.0.1:{}", self.container.get_host_port_ipv4(2379));
diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs
index ca850ee..c358e84 100644
--- a/src/cluster/mod.rs
+++ b/src/cluster/mod.rs
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 mod env;
-mod etcd;
+pub mod etcd;
 mod meta;
 mod node;
diff --git a/src/endpoint.rs b/src/endpoint.rs
index 4be5350..33bd6d1 100644
--- a/src/endpoint.rs
+++ b/src/endpoint.rs
@@ -443,7 +443,7 @@ pub struct Params<'a> {
 /// Owned version of [Params].
 pub type OwnedParams = Params<'static>;
 
-impl<'a> Params<'a> {
+impl Params<'_> {
     fn new(map: LinkedHashMap) -> Self {
         Self { map, _marker: std::marker::PhantomData }
     }
diff --git a/src/kv.rs b/src/kv.rs
new file mode 100644
index 0000000..c03b681
--- /dev/null
+++ b/src/kv.rs
@@ -0,0 +1,264 @@
+// Copyright 2023 The SeamDB Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use lazy_init::Lazy;
+use thiserror::Error;
+
+use crate::cluster::NodeId;
+use crate::protos::{ShardId, TabletId, Temporal, Timestamp, TimestampedKeyValue, Uuid, Value};
+use crate::tablet::{TabletClient, TabletClientError};
+
+type Result<T> = std::result::Result<T, KvError>;
+
+#[derive(Debug, Error)]
+pub enum KvError {
+    #[error("cluster not ready")]
+    ClusterNotReady,
+    #[error("cluster not deployed")]
+    ClusterNotDeployed,
+    #[error("tablet {id} deployment not found")]
+    DeploymentNotFound { id: TabletId },
+    #[error("tablet {id} not deployed")]
+    TabletNotDeployed { id: TabletId },
+    #[error("node {node} not available")]
+    NodeNotAvailable { node: NodeId },
+    #[error("node {node} not connectable: {message}")]
+    NodeNotConnectable { node: NodeId, message: String },
+    #[error("{status}")]
+    GrpcError { status: tonic::Status },
+    #[error("unexpected: {message}")]
+    UnexpectedError { message: String },
+    #[error("tablet {tablet_id} shard {shard_id} contains no shard for {key:?}")]
+    ShardNotFound { tablet_id: TabletId, shard_id: ShardId, key: Vec<u8> },
+    #[error("data corruption: {message}")]
+    DataCorruption { message: String },
+    #[error("key {key:?} already exists")]
+    KeyAlreadyExists { key: Vec<u8> },
+    #[error("key {key:?} got overwritten at {actual}")]
+    KeyTimestampMismatch { key: Vec<u8>, actual: Timestamp },
+    #[error("invalid argument: {message}")]
+    InvalidArgument { message: String },
+    #[error("txn {txn_id} restarted from epoch {from_epoch} to {to_epoch}")]
+    TxnRestarted { txn_id: Uuid, from_epoch: u32, to_epoch: u32 },
+    #[error("txn {txn_id} aborted in epoch {epoch}")]
+    TxnAborted { txn_id: Uuid, epoch: u32 },
+    #[error("txn {txn_id} already committed with epoch {epoch}")]
+    TxnCommitted { txn_id: Uuid, epoch: u32 },
+    #[error(transparent)]
+    Internal(#[from] anyhow::Error),
+}
+
+impl KvError {
+    pub fn unexpected(message: impl Into<String>) -> Self {
+        Self::UnexpectedError { message: message.into() }
+    }
+
+    pub fn corrupted(message: impl Into<String>) -> Self {
+        Self::DataCorruption { message: message.into() }
+    }
+
+    pub fn node_not_available(node: impl Into<NodeId>) -> Self {
+        Self::NodeNotAvailable { node: node.into() }
+    }
+
+    pub fn invalid_argument(message: impl Into<String>) -> Self {
+        Self::InvalidArgument { message: message.into() }
+    }
+}
+
+impl From<TabletClientError> for KvError {
+    fn from(err: TabletClientError) -> Self {
+        match err {
+            TabletClientError::ClusterNotReady => Self::ClusterNotReady,
+            TabletClientError::ClusterNotDeployed => Self::ClusterNotDeployed,
+            TabletClientError::DeploymentNotFound { id } => Self::DeploymentNotFound { id },
+            TabletClientError::TabletNotDeployed { id } => Self::TabletNotDeployed { id },
+            TabletClientError::NodeNotAvailable { node } => Self::NodeNotAvailable { node },
+            TabletClientError::NodeNotConnectable { node, message } => Self::NodeNotConnectable { node, message },
+            TabletClientError::GrpcError { status } => Self::GrpcError { status },
+            TabletClientError::UnexpectedError { message } => Self::UnexpectedError { message },
+            TabletClientError::ShardNotFound { tablet_id, shard_id, key } => {
+                Self::ShardNotFound { tablet_id, shard_id, key }
+            },
+            TabletClientError::DataCorruption { message } => Self::DataCorruption { message },
+            TabletClientError::KeyAlreadyExists { key } => Self::KeyAlreadyExists { key },
+            TabletClientError::KeyTimestampMismatch { key, actual } => Self::KeyTimestampMismatch { key, actual },
+            TabletClientError::InvalidArgument { message } => Self::InvalidArgument { message },
+            TabletClientError::Internal(err) => Self::Internal(err),
+        }
+    }
+}
+
+#[derive(Clone, Copy, PartialEq, Eq, Hash)]
+pub enum KvSemantics {
+    Snapshot,
+    Transactional,
+    Inconsistent,
+}
+
+#[async_trait::async_trait]
+pub trait KvClient: Send + Sync {
+    async fn get(&self, key: &[u8]) -> Result<Option<(Timestamp, Value)>>;
+    async fn scan(&self, start: &[u8], end: &[u8], limit: u32) -> Result<(Vec<u8>, Vec<TimestampedKeyValue>)>;
+
+    async fn put(&self, key: &[u8], value: Option<Value>, expect_ts: Option<Timestamp>) -> Result<Timestamp>;
+    async fn increment(&self, key: &[u8], increment: i64) -> Result<i64>;
+
+    fn client(&self) -> &TabletClient;
+
+    fn semantics(&self) -> KvSemantics;
+
+    async fn commit(&self) -> Result<Timestamp> {
+        Err(KvError::unexpected("not supported".to_string()))
+    }
+
+    async fn abort(&self) -> Result<()> {
+        Err(KvError::unexpected("not supported".to_string()))
+    }
+
+    fn restart(&self) -> Result<()> {
+        Err(KvError::unexpected("not supported".to_string()))
+    }
+}
+
+#[async_trait::async_trait]
+impl KvClient for TabletClient {
+    fn client(&self) -> &TabletClient {
+        self
+    }
+
+    fn semantics(&self) -> KvSemantics {
+        KvSemantics::Inconsistent
+    }
+
+    async fn get(&self, key: &[u8]) -> Result<Option<(Timestamp, Value)>> {
+        match self.get_directly(Temporal::default(), key, 0).await? {
+            (_temporal, None) => Err(KvError::unexpected("timestamp get got no response")),
+            (_temporal, Some(response)) => Ok(response.value.map(|v| v.into_parts())),
+        }
+    }
+
+    async fn scan(&self, start: &[u8], end: &[u8], limit: u32) -> Result<(Vec<u8>, Vec<TimestampedKeyValue>)> {
+        match self.scan_directly(Temporal::default(), start, end, limit).await? {
+            (_temporal, None) => Err(KvError::unexpected("no scan response")),
+            (_temporal, Some(response)) => Ok(response),
+        }
+    }
+
+    async fn put(&self, key: &[u8], value: Option<Value>, expect_ts: Option<Timestamp>) -> Result<Timestamp> {
+        match self.put_directly(Temporal::default(), key, value, expect_ts, 0).await? {
+            (_temporal, None) => Err(KvError::unexpected("timestamp put got aborted")),
+            (_temporal, Some(write_ts)) => Ok(write_ts),
+        }
+    }
+
+    async fn increment(&self, key: &[u8], increment: i64) -> Result<i64> {
+        match self.increment_directly(Temporal::default(), key, increment, 0).await? {
+            (_temporal, None) => Err(KvError::unexpected("timestamp increment got aborted")),
+            (_temporal, Some(incremented)) => Ok(incremented),
+        }
+    }
+}
+
+pub struct LazyInitTimestampedKvClient {
+    client: TabletClient,
+    lazy: Lazy<TimestampedKvClient>,
+}
+
+impl LazyInitTimestampedKvClient {
+    pub fn new(client: TabletClient) -> Self {
+        Self { client, lazy: Lazy::new() }
+    }
+
+    fn client(&self) -> &TimestampedKvClient {
+        self.lazy.get_or_create(|| TimestampedKvClient::new(self.client.clone(), self.client.now()))
+    }
+}
+
+#[async_trait::async_trait]
+impl KvClient for LazyInitTimestampedKvClient {
+    fn client(&self) -> &TabletClient {
+        &self.client
+    }
+
+    fn semantics(&self) -> KvSemantics {
+        KvSemantics::Snapshot
+    }
+
+    async fn get(&self, key: &[u8]) -> Result<Option<(Timestamp, Value)>> {
+        self.client().get(key).await
+    }
+
+    async fn scan(&self, start: &[u8], end: &[u8], limit: u32) -> Result<(Vec<u8>, Vec<TimestampedKeyValue>)> {
+        self.client().scan(start, end, limit).await
+    }
+
+    async fn put(&self, key: &[u8], value: Option<Value>, expect_ts: Option<Timestamp>) -> Result<Timestamp> {
+        self.client().put(key, value, expect_ts).await
+    }
+
+    async fn increment(&self, key: &[u8], increment: i64) -> Result<i64> {
+        self.client().increment(key, increment).await
+    }
+}
+
+pub struct TimestampedKvClient {
+    client: TabletClient,
+    timestamp: Timestamp,
+}
+
+impl TimestampedKvClient {
+    pub fn new(client: TabletClient, timestamp: Timestamp) -> Self {
+        Self { client, timestamp }
+    }
+}
+
+#[async_trait::async_trait]
+impl KvClient for TimestampedKvClient {
+    fn client(&self) -> &TabletClient {
+        &self.client
+    }
+
+    fn semantics(&self) -> KvSemantics {
+        KvSemantics::Snapshot
+    }
+
+    async fn get(&self, key: &[u8]) -> Result<Option<(Timestamp, Value)>> {
+        match self.client.get_directly(self.timestamp.into(), key, 0).await? {
+            (_temporal, None) => Err(KvError::unexpected("timestamp get got no response")),
+            (_temporal, Some(response)) => Ok(response.value.map(|r| r.into_parts())),
+        }
+    }
+
+    async fn scan(&self, start: &[u8], end: &[u8], limit: u32) -> Result<(Vec<u8>, Vec<TimestampedKeyValue>)> {
+        match self.client.scan_directly(self.timestamp.into(), start, end, limit).await? {
+            (_temporal, None) => Err(KvError::unexpected("no scan response")),
+            (_temporal, Some(response)) => Ok(response),
+        }
+    }
+
+    async fn put(&self, key: &[u8], value: Option<Value>, expect_ts: Option<Timestamp>) -> Result<Timestamp> {
+        match self.client.put_directly(self.timestamp.into(), key, value, expect_ts, 0).await? {
+            (_temporal, None) => Err(KvError::unexpected("timestamp put got aborted")),
+            (_temporal, Some(write_ts)) => Ok(write_ts),
+        }
+    }
+
+    async fn increment(&self, key: &[u8], increment: i64) -> Result<i64> {
+        match self.client.increment_directly(self.timestamp.into(), key, increment, 0).await? {
+            (_temporal, None) => Err(KvError::unexpected("timestamp increment got aborted")),
+            (_temporal, Some(incremented)) => Ok(incremented),
+        }
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 87a0ee6..30e8518 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -18,8 +18,10 @@ pub mod clock;
 pub mod cluster;
 pub mod endpoint;
 pub mod keys;
+pub mod kv;
 pub mod log;
 pub mod protos;
+pub mod sql;
 pub mod tablet;
 pub mod timer;
 pub mod txn;
diff --git a/src/log/kafka.rs b/src/log/kafka.rs
index e0fef3d..4e21dae 100644
--- a/src/log/kafka.rs
+++ b/src/log/kafka.rs
@@ -22,7 +22,7 @@ use anyhow::{anyhow, Error, Result};
 use async_trait::async_trait;
 use bytesize::ByteSize;
 use compact_str::ToCompactString;
-use derivative::Derivative;
+use derive_where::derive_where;
 use rdkafka::admin::{AdminClient, AdminOptions, NewTopic, TopicReplication};
 use rdkafka::client::DefaultClientContext;
 use rdkafka::config::{ClientConfig, FromClientConfig};
@@ -36,15 +36,14 @@ use tokio::time;
 use crate::endpoint::{Endpoint, Params};
 use crate::log::{ByteLogProducer, ByteLogSubscriber, LogClient, LogFactory, LogOffset, LogPosition};
 
-#[derive(Derivative)]
-#[derivative(Debug)]
+#[derive_where(Debug)]
 struct KafkaPartitionProducer {
     topic: String,
     partition: i32,
     queue: VecDeque<Vec<u8>>,
-    #[derivative(Debug = "ignore")]
+    #[derive_where(skip(Debug))]
    producer: FutureProducer,
-    #[derivative(Debug = "ignore")]
+    #[derive_where(skip(Debug))]
    deliveries: VecDeque<DeliveryFuture>,
 }
@@ -131,19 +130,17 @@ impl ByteLogProducer for KafkaPartitionProducer {
     }
 }
 
-#[derive(Derivative)]
-#[derivative(Debug)]
+#[derive_where(Debug)]
 struct PartitionConsumer {
     topic: String,
     partition: i32,
-    #[derivative(Debug = "ignore")]
+    #[derive_where(skip(Debug))]
     consumer: StreamConsumer,
 }
 
-#[derive(Derivative)]
-#[derivative(Debug)]
+#[derive_where(Debug)]
 struct KafkaPartitionConsumer {
-    #[derivative(Debug = "ignore")]
+    #[derive_where(skip(Debug))]
     message: Option<BorrowedMessage<'static>>,
     consumer: Arc<PartitionConsumer>,
 }
@@ -156,7 +153,7 @@ impl KafkaPartitionConsumer {
 
 #[async_trait]
 impl ByteLogSubscriber for KafkaPartitionConsumer {
-    async fn read(&mut self) -> Result<(LogPosition, &[u8])> {
+    async fn read<'a>(&'a mut self) -> Result<(LogPosition, &'a [u8])> {
         self.message = None;
         let message = self.consumer.consumer.recv().await?;
         let payload = message.payload().unwrap_or(Default::default());
@@ -190,11 +187,10 @@ impl ByteLogSubscriber for KafkaPartitionConsumer {
     }
 }
 
-#[derive(Derivative)]
-#[derivative(Debug)]
+#[derive_where(Debug)]
 pub struct KafkaLogClient {
     config: ClientConfig,
-    #[derivative(Debug = "ignore")]
+    #[derive_where(skip(Debug))]
     client: AdminClient<DefaultClientContext>,
     replication: i32,
 }
@@ -265,7 +261,8 @@ impl LogClient for KafkaLogClient {
     }
 
     async fn create_log(&self, name: &str, retention: ByteSize) -> Result<()> {
-        let retention_bytes = retention.0.to_compact_string();
+        let retention_bytes =
+            if retention == ByteSize::default() { "-1".to_compact_string() } else { retention.0.to_compact_string() };
         let topics = [new_topic_config(name, self.replication, &retention_bytes)];
         let mut results = self.client.create_topics(&topics, &AdminOptions::default()).await?;
         let topic_result = results.pop().ok_or_else(|| anyhow!("no topic results in topic creation"))?;
diff --git a/src/log/memory.rs b/src/log/memory.rs
index ceedb5e..8052377 100644
--- a/src/log/memory.rs
+++ b/src/log/memory.rs
@@ -130,7 +130,7 @@ struct MemoryLogSubscriber {
 
 #[async_trait]
 impl ByteLogSubscriber for MemoryLogSubscriber {
-    async fn read(&mut self) -> Result<(LogPosition, &[u8])> {
+    async fn read<'a>(&'a mut self) -> Result<(LogPosition, &'a [u8])> {
         loop {
             match self.log.read(self.offset)? {
                 Either::Left(data) => {
diff --git a/src/log/mod.rs b/src/log/mod.rs
index 331bd0c..af0d97f 100644
--- a/src/log/mod.rs
+++ b/src/log/mod.rs
@@ -179,7 +179,7 @@ pub trait ByteLogProducer: Send + std::fmt::Debug + 'static {
 /// Subscriber to read byte message from log.
 #[async_trait]
 pub trait ByteLogSubscriber: Send + Sync + std::fmt::Debug {
-    async fn read(&mut self) -> Result<(LogPosition, &[u8])>;
+    async fn read<'a>(&'a mut self) -> Result<(LogPosition, &'a [u8])>;
 
     async fn seek(&mut self, offset: LogOffset) -> Result<()>;
diff --git a/src/protos/build/src/main.rs b/src/protos/build/src/main.rs
index c7ccbfd..2bed308 100644
--- a/src/protos/build/src/main.rs
+++ b/src/protos/build/src/main.rs
@@ -24,7 +24,6 @@ fn main() {
     let protos: Vec<_> = protos_dir
         .read_dir()
         .unwrap()
-        .into_iter()
         .map(|entry| entry.unwrap().path())
         .filter(|p| p.file_name().unwrap().to_str().unwrap().ends_with(".proto"))
         .collect();
@@ -37,15 +36,39 @@ fn main() {
     let mut config = prost_build::Config::new();
     config
         .skip_debug(std::iter::once("Uuid"))
-        .type_attribute("Timestamp", "#[derive(Eq, PartialOrd, Ord)]")
+        .type_attribute("Timestamp", "#[derive(Eq, Hash, PartialOrd, Ord)]")
         .type_attribute("MessageId", "#[derive(Eq, PartialOrd, Ord)]")
+        .type_attribute("KeySpan", "#[derive(Eq, PartialOrd, Ord)]")
+        .type_attribute("KeyRange", "#[derive(Eq, PartialOrd, Ord)]")
         .type_attribute("Uuid", "#[derive(Eq, Hash, PartialOrd, Ord)]")
+        .type_attribute("ColumnDescriptor", "#[derive(Eq, Hash, PartialOrd)]")
+        .type_attribute("IndexDescriptor", "#[derive(Eq, Hash, PartialOrd)]")
+        .type_attribute("TableDescriptor", "#[derive(Eq, Hash, PartialOrd)]")
+        .type_attribute("ColumnValue", "#[derive(Eq, Hash, PartialOrd)]")
+        .type_attribute("ColumnTypeDeclaration", "#[derive(Eq, Hash, PartialOrd)]")
+        .type_attribute("NumericTypeDeclaration", "#[derive(Eq, Hash, PartialOrd)]")
+        .type_attribute("CharacterTypeDeclaration", "#[derive(Eq, Hash, PartialOrd)]")
+        .type_attribute("StoringFloat32", "#[derive(PartialOrd)]")
+        .type_attribute("StoringFloat64", "#[derive(PartialOrd)]")
+        .require_field("DescriptorMeta.timestamp")
+        .require_field("TableDescriptor.timestamp")
+        .require_field("SchemaDescriptor.timestamp")
+        .require_field("DatabaseDescriptor.timestamp")
+        .oneof_enum("ColumnTypeDeclaration")
         .oneof_enum("Value")
         .oneof_enum("Temporal")
+        .oneof_enum("Transient")
         .oneof_enum("DataRequest")
+        .oneof_enum("ColumnValue")
         .oneof_enum("DataResponse")
+        .oneof_enum("DataError")
         .enumerate_field(".seamdb")
+        .enumerate_field(".sql")
         .require_field(".seamdb.TabletWatermark")
+        .require_field("RefreshReadError.temporal")
+        .require_field("ConflictWriteError.transient")
+        .require_field("TimestampMismatchError.actual")
+        .require_field("BatchError.error")
         .require_field("ShardRequest.request")
         .require_field("ShardResponse.response")
         .require_field("BatchRequest.temporal")
@@ -88,7 +111,7 @@ fn main() {
             file.write_all(b"\n").unwrap();
         }
         file.write_all(b"#[rustfmt::skip]\n").unwrap();
-        write!(&mut file, "mod {};\n", module).unwrap();
-        write!(&mut file, "pub use self::{}::*;\n", module).unwrap();
+        writeln!(&mut file, "mod {};", module).unwrap();
+        writeln!(&mut file, "pub use self::{}::*;", module).unwrap();
     }
 }
diff --git a/src/protos/errors.rs b/src/protos/errors.rs
new file mode 100644
index 0000000..7e9f61d
--- /dev/null
+++ b/src/protos/errors.rs
@@ -0,0 +1,134 @@
+// Copyright 2023 The SeamDB Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::error::Error;
+use std::fmt::{Display, Formatter, Result};
+
+use super::*;
+
+impl Display for ConflictWriteError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+        match &self.transient {
+            Transient::Timestamp(ts) => write!(f, "conflict with write to key {:?} at {ts}", self.key),
+            Transient::Transaction(txn) => {
+                write!(f, "conflict with write to key {:?} from {}", self.key, txn)
+            },
+        }
+    }
+}
+
+impl Error for ConflictWriteError {}
+
+impl Display for DataTypeMismatchError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+        write!(f, "expect key {:?} to have type {:?}, but got {:?}", self.key, self.expect, self.actual)
+    }
+}
+
+impl Error for DataTypeMismatchError {}
+
+impl Display for ShardNotFoundError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+        match self.shard_id {
+            0 => write!(f, "found no shard for key {:?}", self.key),
+            shard_id => write!(f, "cannot find shard {} for key {:?}", shard_id, self.key),
+        }
+    }
+}
+
+impl Error for ShardNotFoundError {}
+
+impl Display for TimestampMismatchError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+        write!(f, "key {:?} got overwritten at timestamp {}", self.key, self.actual)
+    }
+}
+
+impl Error for TimestampMismatchError {}
+
+impl Display for StoreError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+        f.write_str(&self.message)
+    }
+}
+
+impl Error for StoreError {}
+
+impl Display for SimpleError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+        f.write_str(&self.message)
+    }
+}
+
+impl Error for SimpleError {}
+
+impl Display for DataError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+        match self {
+            Self::ConflictWrite(err) => err.fmt(f),
+            Self::MismatchDataType(err) => err.fmt(f),
+            Self::ShardNotFound(err) => err.fmt(f),
+            Self::TimestampMismatch(err) => err.fmt(f),
+            Self::Store(err) => err.fmt(f),
+            Self::Internal(err) => err.fmt(f),
+        }
+    }
+}
+
+impl Error for DataError {}
+
+impl DataError {
+    pub fn shard_not_found(key: impl Into<Vec<u8>>, shard_id: ShardId) -> Self {
+        Self::ShardNotFound(ShardNotFoundError { key: key.into(), shard_id: shard_id.into() })
+    }
+
+    pub fn conflict_write(key: impl Into<Vec<u8>>, transient: impl Into<Transient>) -> Self {
+        Self::ConflictWrite(ConflictWriteError { key: key.into(), transient: transient.into() })
+    }
+
+    pub fn timestamp_mismatch(key: impl Into<Vec<u8>>, actual: Timestamp) -> Self {
+        Self::TimestampMismatch(TimestampMismatchError { key: key.into(), actual })
+    }
+
+    pub fn internal(message: impl Into<String>) -> Self {
+        Self::Internal(SimpleError { message: message.into() })
+    }
+}
+
+impl Display for BatchError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+        match self.request_index {
+            None => write!(f, "batch request failed due to {}", self.error),
+            Some(i) => {
+                write!(f, "request {} in batch failed due to {}", i, self.error)
+            },
+        }
+    }
+}
+
+impl Error for BatchError {}
+
+impl BatchError {
+    pub fn new(error: DataError) -> Self {
+        Self { request_index: None, error }
+    }
+
+    pub fn with_index(request_index: usize, error: DataError) -> Self {
+        Self { request_index: Some(request_index as u32), error }
+    }
+
+    pub fn with_message(msg: impl Into<String>) -> Self {
+        Self { request_index: None, error: DataError::Internal(SimpleError { message: msg.into() }) }
+    }
+}
diff --git a/src/protos/generated/mod.rs b/src/protos/generated/mod.rs
index 5f67409..64472fb 100644
--- a/src/protos/generated/mod.rs
+++ b/src/protos/generated/mod.rs
@@ -3,3 +3,7 @@
 #[rustfmt::skip]
 mod seamdb;
 pub use self::seamdb::*;
+
+#[rustfmt::skip]
+mod sql;
+pub use self::sql::*;
diff --git a/src/protos/generated/seamdb.rs b/src/protos/generated/seamdb.rs
index fe42df6..cf23f5d 100644
--- a/src/protos/generated/seamdb.rs
+++ b/src/protos/generated/seamdb.rs
@@ -19,6 +19,7 @@ pub struct MessageId {
     #[prost(uint64, tag = "2")]
     pub sequence: u64,
 }
+#[derive(Eq, PartialOrd, Ord)]
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct KeyRange {
@@ -27,6 +28,7 @@ pub struct KeyRange {
     #[prost(bytes = "vec", tag = "2")]
     pub end: ::prost::alloc::vec::Vec<u8>,
 }
+#[derive(Eq, PartialOrd, Ord)]
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct KeySpan {
@@ -146,7 +148,7 @@ pub struct TabletManifest {
     #[prost(string, repeated, tag = "35")]
     pub obsoleted_files: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
 }
-#[derive(Eq, PartialOrd, Ord)]
+#[derive(Eq, Hash, PartialOrd, Ord)]
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, Copy, PartialEq, ::prost::Message)]
 pub struct Timestamp {
@@ -198,6 +200,14 @@ pub struct KeyValue {
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
+pub enum Transient {
+    #[prost(message, tag = "1")]
+    Timestamp(Timestamp),
+    #[prost(message, tag = "2")]
+    Transaction(TxnMeta),
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
 pub enum Temporal {
     #[prost(message, tag = "1")]
     Timestamp(Timestamp),
@@ -574,8 +584,12 @@ pub struct RefreshReadRequest {
     pub from: Timestamp,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, Copy, PartialEq, ::prost::Message)]
-pub struct RefreshReadResponse {}
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct RefreshReadResponse {
+    /// Empty if full span refreshed.
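+    /// Otherwise, the key to resume refreshing from.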
+    #[prost(bytes = "vec", tag = "1")]
+    pub resume_key: ::prost::alloc::vec::Vec<u8>,
+}
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct ParticipateTxnRequest {
@@ -604,6 +618,76 @@ pub struct TabletHeartbeatResponse {
     #[prost(message, optional, tag = "1")]
     pub deployment: ::core::option::Option<TabletDeployment>,
 }
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct BatchError {
+    #[prost(uint32, optional, tag = "1")]
+    pub request_index: ::core::option::Option<u32>,
+    #[prost(message, required, tag = "2")]
+    pub error: DataError,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub enum DataError {
+    #[prost(message, tag = "1")]
+    ConflictWrite(ConflictWriteError),
+    #[prost(message, tag = "2")]
+    MismatchDataType(DataTypeMismatchError),
+    #[prost(message, tag = "3")]
+    ShardNotFound(ShardNotFoundError),
+    #[prost(message, tag = "4")]
+    TimestampMismatch(TimestampMismatchError),
+    #[prost(message, tag = "5")]
+    Store(StoreError),
+    #[prost(message, tag = "6")]
+    Internal(SimpleError),
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ConflictWriteError {
+    #[prost(bytes = "vec", tag = "1")]
+    pub key: ::prost::alloc::vec::Vec<u8>,
+    #[prost(message, required, tag = "2")]
+    pub transient: Transient,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct DataTypeMismatchError {
+    #[prost(bytes = "vec", tag = "1")]
+    pub key: ::prost::alloc::vec::Vec<u8>,
+    #[prost(enumeration = "ValueType", tag = "2")]
+    pub expect: ValueType,
+    #[prost(enumeration = "ValueType", tag = "3")]
+    pub actual: ValueType,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct TimestampMismatchError {
+    #[prost(bytes = "vec", tag = "1")]
+    pub key: ::prost::alloc::vec::Vec<u8>,
+    #[prost(message, required, tag = "2")]
+    pub actual: Timestamp,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ShardNotFoundError {
+    #[prost(bytes = "vec", tag = "1")]
+    pub key: ::prost::alloc::vec::Vec<u8>,
+    #[prost(uint64, tag = "2")]
+    pub shard_id: u64,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct StoreError {
+    #[prost(string, tag = "1")]
+    pub message: ::prost::alloc::string::String,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct SimpleError {
+    #[prost(string, tag = "1")]
+    pub message: ::prost::alloc::string::String,
+}
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
 #[repr(i32)]
 pub enum ShardMergeBounds {
@@ -638,6 +722,41 @@ impl ShardMergeBounds {
 }
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
 #[repr(i32)]
+pub enum ValueType {
+    Absent = 0,
+    Int = 1,
+    Float = 2,
+    Bytes = 3,
+    String = 4,
+}
+impl ValueType {
+    /// String value of the enum field names used in the ProtoBuf definition.
+    ///
+    /// The values are not transformed in any way and thus are considered stable
+    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+    pub fn as_str_name(&self) -> &'static str {
+        match self {
+            ValueType::Absent => "Absent",
+            ValueType::Int => "Int",
+            ValueType::Float => "Float",
+            ValueType::Bytes => "Bytes",
+            ValueType::String => "String",
+        }
+    }
+    /// Creates an enum from field names used in the ProtoBuf definition.
+    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+        match value {
+            "Absent" => Some(Self::Absent),
+            "Int" => Some(Self::Int),
+            "Float" => Some(Self::Float),
+            "Bytes" => Some(Self::Bytes),
+            "String" => Some(Self::String),
+            _ => None,
+        }
+    }
+}
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+#[repr(i32)]
 pub enum TxnStatus {
     Pending = 0,
     /// Terminal status.
diff --git a/src/protos/generated/sql.rs b/src/protos/generated/sql.rs
new file mode 100644
index 0000000..f4fc549
--- /dev/null
+++ b/src/protos/generated/sql.rs
@@ -0,0 +1,211 @@
+// This file is @generated by prost-build.
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct DescriptorMeta {
+    #[prost(uint64, tag = "1")]
+    pub id: u64,
+    #[prost(string, tag = "2")]
+    pub name: ::prost::alloc::string::String,
+    #[prost(uint64, tag = "3")]
+    pub database_id: u64,
+    #[prost(uint64, tag = "4")]
+    pub parent_id: u64,
+    #[prost(message, required, tag = "5")]
+    pub timestamp: super::seamdb::Timestamp,
+    #[prost(bytes = "vec", tag = "6")]
+    pub blob: ::prost::alloc::vec::Vec<u8>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct DatabaseDescriptor {
+    #[prost(uint64, tag = "1")]
+    pub id: u64,
+    #[prost(string, tag = "2")]
+    pub name: ::prost::alloc::string::String,
+    /// Populated from kv layer.
+    #[prost(message, required, tag = "3")]
+    pub timestamp: super::seamdb::Timestamp,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct SchemaDescriptor {
+    #[prost(uint64, tag = "1")]
+    pub id: u64,
+    #[prost(string, tag = "2")]
+    pub name: ::prost::alloc::string::String,
+    #[prost(uint64, tag = "3")]
+    pub database_id: u64,
+    /// Populated from kv layer.
+    #[prost(message, required, tag = "4")]
+    pub timestamp: super::seamdb::Timestamp,
+}
+#[derive(Eq, Hash, PartialOrd)]
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct TableDescriptor {
+    #[prost(uint64, tag = "1")]
+    pub id: u64,
+    #[prost(string, tag = "2")]
+    pub name: ::prost::alloc::string::String,
+    #[prost(uint64, tag = "3")]
+    pub database_id: u64,
+    #[prost(uint64, tag = "4")]
+    pub schema_id: u64,
+    /// Populated from kv layer.
+    #[prost(message, required, tag = "5")]
+    pub timestamp: super::seamdb::Timestamp,
+    #[prost(uint32, tag = "7")]
+    pub last_column_id: u32,
+    #[prost(message, repeated, tag = "8")]
+    pub columns: ::prost::alloc::vec::Vec<ColumnDescriptor>,
+    #[prost(uint32, tag = "9")]
+    pub last_index_id: u32,
+    #[prost(message, repeated, tag = "10")]
+    pub indices: ::prost::alloc::vec::Vec<IndexDescriptor>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, Copy, PartialEq, ::prost::Message)]
+pub struct ColumnType {
+    #[prost(enumeration = "ColumnTypeKind", tag = "1")]
+    pub kind: ColumnTypeKind,
+    #[prost(message, optional, tag = "2")]
+    pub declaration: ::core::option::Option<ColumnTypeDeclaration>,
+}
+#[derive(Eq, Hash, PartialOrd)]
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, Copy, PartialEq, ::prost::Message)]
+pub struct NumericTypeDeclaration {
+    #[prost(uint32, tag = "1")]
+    pub precision: u32,
+    #[prost(uint32, tag = "2")]
+    pub scale: u32,
+}
+#[derive(Eq, Hash, PartialOrd)]
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, Copy, PartialEq, ::prost::Message)]
+pub struct CharacterTypeDeclaration {
+    #[prost(uint32, tag = "1")]
+    pub max_length: u32,
+}
+#[derive(Eq, Hash, PartialOrd)]
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, Copy, PartialEq, ::prost::Message)]
+pub enum ColumnTypeDeclaration {
+    #[prost(message, tag = "1")]
+    Numeric(NumericTypeDeclaration),
+    #[prost(message, tag = "2")]
+    Character(CharacterTypeDeclaration),
+}
+#[derive(PartialOrd)]
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, Copy, PartialEq, ::prost::Message)]
+pub struct StoringFloat32 {
+    #[prost(float, tag = "1")]
+    pub value: f32,
+}
+#[derive(PartialOrd)]
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, Copy, PartialEq, ::prost::Message)]
+pub struct StoringFloat64 {
+    #[prost(double, tag = "1")]
+    pub value: f64,
+}
+#[derive(Eq, Hash, PartialOrd)]
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub enum ColumnValue {
+    #[prost(bool, tag = "1")]
+    Boolean(bool),
+    #[prost(int32, tag = "2")]
+    Int16(i32),
+    #[prost(int32, tag = "3")]
+    Int32(i32),
+    #[prost(int64, tag = "4")]
+    Int64(i64),
+    #[prost(message, tag = "5")]
+    Float32(StoringFloat32),
+    #[prost(message, tag = "6")]
+    Float64(StoringFloat64),
+    #[prost(bytes, tag = "7")]
+    Bytes(::prost::alloc::vec::Vec<u8>),
+    #[prost(string, tag = "8")]
+    String(::prost::alloc::string::String),
+}
+#[derive(Eq, Hash, PartialOrd)]
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ColumnDescriptor {
+    #[prost(uint32, tag = "1")]
+    pub id: u32,
+    #[prost(string, tag = "2")]
+    pub name: ::prost::alloc::string::String,
+    #[prost(bool, tag = "3")]
+    pub nullable: bool,
+    #[prost(bool, tag = "4")]
+    pub serial: bool,
+    #[prost(enumeration = "ColumnTypeKind", tag = "5")]
+    pub type_kind: ColumnTypeKind,
+    #[prost(message, optional, tag = "6")]
+    pub type_declaration: ::core::option::Option<ColumnTypeDeclaration>,
+    #[prost(message, optional, tag = "7")]
+    pub default_value: ::core::option::Option<ColumnValue>,
+}
+#[derive(Eq, Hash, PartialOrd)]
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct IndexDescriptor {
+    #[prost(uint32, tag = "1")]
+    pub id: u32,
+    #[prost(string, tag = "2")]
+    pub name: ::prost::alloc::string::String,
+    #[prost(bool, tag = "3")]
+    pub unique: bool,
+    #[prost(uint32, repeated, tag = "4")]
+    pub column_ids: ::prost::alloc::vec::Vec<u32>,
+    #[prost(uint32, repeated, tag = "5")]
+    pub storing_column_ids: ::prost::alloc::vec::Vec<u32>,
+}
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+#[repr(i32)]
+pub enum ColumnTypeKind {
+    Boolean = 0,
+    Int16 = 1,
+    Int32 = 2,
+    Int64 = 3,
+    Float32 = 4,
+    Float64 = 5,
+    Bytes = 6,
+    String = 7,
+}
+impl ColumnTypeKind {
+    /// String value of the enum field names used in the ProtoBuf definition.
+    ///
+    /// The values are not transformed in any way and thus are considered stable
+    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+    pub fn as_str_name(&self) -> &'static str {
+        match self {
+            ColumnTypeKind::Boolean => "Boolean",
+            ColumnTypeKind::Int16 => "Int16",
+            ColumnTypeKind::Int32 => "Int32",
+            ColumnTypeKind::Int64 => "Int64",
+            ColumnTypeKind::Float32 => "Float32",
+            ColumnTypeKind::Float64 => "Float64",
+            ColumnTypeKind::Bytes => "Bytes",
+            ColumnTypeKind::String => "String",
+        }
+    }
+    /// Creates an enum from field names used in the ProtoBuf definition.
+    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+        match value {
+            "Boolean" => Some(Self::Boolean),
+            "Int16" => Some(Self::Int16),
+            "Int32" => Some(Self::Int32),
+            "Int64" => Some(Self::Int64),
+            "Float32" => Some(Self::Float32),
+            "Float64" => Some(Self::Float64),
+            "Bytes" => Some(Self::Bytes),
+            "String" => Some(Self::String),
+            _ => None,
+        }
+    }
+}
diff --git a/src/protos/mod.rs b/src/protos/mod.rs
index 7707881..39fa466 100644
--- a/src/protos/mod.rs
+++ b/src/protos/mod.rs
@@ -16,7 +16,9 @@
 
 #[rustfmt::skip]
 mod generated;
+mod errors;
 mod span;
+mod sql;
 mod temporal;
 mod uuid;
 
@@ -30,6 +32,7 @@ pub use temporal::{HasTxnMeta, HasTxnStatus};
 pub use self::data_message::Operation as DataOperation;
 pub use self::generated::*;
 pub use self::span::*;
+pub use self::sql::*;
 pub use self::tablet_service_client::TabletServiceClient;
 pub use self::tablet_service_server::{TabletService, TabletServiceServer};
 pub use crate::keys;
@@ -382,6 +385,10 @@ impl TimestampedValue {
     pub fn read_bytes(&self, key: &[u8], operation: &str) -> Result<&[u8]> {
         self.value.read_bytes(key, operation)
     }
+
+    pub fn into_parts(self) -> (Timestamp, Value) {
+        (self.timestamp, self.value)
+    }
 }
 
 impl TimestampedKeyValue {
@@ -492,6 +499,19 @@ impl DataResponse {
             _ => Err(self),
         }
     }
+
+    pub fn into_refresh_read(self) -> Result<RefreshReadResponse, Self> {
+        match self {
+            Self::RefreshRead(refresh_read) => Ok(refresh_read),
+            _ => Err(self),
+        }
+    }
+}
+
+impl ScanResponse {
+    pub fn into_parts(self) -> (Vec<u8>, Vec<TimestampedKeyValue>) {
+        (self.resume_key, self.rows)
+    }
 }
 
 impl ClusterDescriptor {
diff --git a/src/protos/protos/seamdb.proto b/src/protos/protos/seamdb.proto
index 8bbf7a5..6ebd7e5 100644
--- a/src/protos/protos/seamdb.proto
+++ b/src/protos/protos/seamdb.proto
@@ -134,6 +134,14 @@ message Value {
   }
 }
 
+enum ValueType {
+  Absent = 0;
+  Int = 1;
+  Float = 2;
+  Bytes = 3;
+  String = 4;
+}
+
 message TimestampedValue {
   Value value = 1;
   Timestamp timestamp = 2;
@@ -151,6 +159,13 @@ message KeyValue {
   Value value = 2;
 }
 
+message Transient {
+  oneof value {
+    Timestamp timestamp = 1;
+    TxnMeta transaction = 2;
+  }
+}
+
 message Temporal {
   oneof value {
     Timestamp timestamp = 1;
@@ -448,6 +463,8 @@ message RefreshReadRequest {
 }
 
 message RefreshReadResponse {
+  // Empty if full span refreshed.
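+  // Otherwise, the key to resume refreshing from.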
+  bytes resume_key = 1;
 }
 
 message ParticipateTxnRequest {
@@ -479,3 +496,48 @@ service TabletService {
 
   rpc Locate(LocateRequest) returns (LocateResponse);
 }
+
+message BatchError {
+  optional uint32 request_index = 1;
+  DataError error = 2;
+}
+
+message DataError {
+  oneof error {
+    ConflictWriteError conflict_write = 1;
+    DataTypeMismatchError mismatch_data_type = 2;
+    ShardNotFoundError shard_not_found = 3;
+    TimestampMismatchError timestamp_mismatch = 4;
+    StoreError store = 5;
+    SimpleError internal = 6;
+  }
+}
+
+message ConflictWriteError {
+  bytes key = 1;
+  Transient transient = 2;
+}
+
+message DataTypeMismatchError {
+  bytes key = 1;
+  ValueType expect = 2;
+  ValueType actual = 3;
+}
+
+message TimestampMismatchError {
+  bytes key = 1;
+  Timestamp actual = 2;
+}
+
+message ShardNotFoundError {
+  bytes key = 1;
+  uint64 shard_id = 2;
+}
+
+message StoreError {
+  string message = 1;
+}
+
+message SimpleError {
+  string message = 1;
+}
diff --git a/src/protos/protos/sql.proto b/src/protos/protos/sql.proto
new file mode 100644
index 0000000..f321d2e
--- /dev/null
+++ b/src/protos/protos/sql.proto
@@ -0,0 +1,132 @@
+// Copyright 2023 The SeamDB Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package sql;
+
+import "seamdb.proto";
+
+message DescriptorMeta {
+  uint64 id = 1;
+  string name = 2;
+  uint64 database_id = 3;
+  uint64 parent_id = 4;
+  seamdb.Timestamp timestamp = 5;
+  bytes blob = 6;
+}
+
+message DatabaseDescriptor {
+  uint64 id = 1;
+  string name = 2;
+
+  // Populated from kv layer.
+  seamdb.Timestamp timestamp = 3;
+}
+
+message SchemaDescriptor {
+  uint64 id = 1;
+  string name = 2;
+  uint64 database_id = 3;
+
+  // Populated from kv layer.
+  seamdb.Timestamp timestamp = 4;
+}
+
+message TableDescriptor {
+  uint64 id = 1;
+  string name = 2;
+
+  uint64 database_id = 3;
+  uint64 schema_id = 4;
+
+  // Populated from kv layer.
+  seamdb.Timestamp timestamp = 5;
+
+  uint32 last_column_id = 7;
+  repeated ColumnDescriptor columns = 8;
+  uint32 last_index_id = 9;
+  repeated IndexDescriptor indices = 10;
+}
+
+message ColumnType {
+  ColumnTypeKind kind = 1;
+  ColumnTypeDeclaration declaration = 2;
+}
+
+enum ColumnTypeKind {
+  Boolean = 0;
+  Int16 = 1;
+  Int32 = 2;
+  Int64 = 3;
+  Float32 = 4;
+  Float64 = 5;
+  Bytes = 6;
+  String = 7;
+}
+
+message NumericTypeDeclaration {
+  uint32 precision = 1;
+  uint32 scale = 2;
+}
+
+message CharacterTypeDeclaration {
+  uint32 max_length = 1;
+}
+
+message ColumnTypeDeclaration {
+  oneof value {
+    NumericTypeDeclaration numeric = 1;
+    CharacterTypeDeclaration character = 2;
+  }
+}
+
+message StoringFloat32 {
+  float value = 1;
+}
+
+message StoringFloat64 {
+  double value = 1;
+}
+
+message ColumnValue {
+  oneof value {
+    bool boolean = 1;
+    int32 int16 = 2;
+    int32 int32 = 3;
+    int64 int64 = 4;
+    StoringFloat32 float32 = 5;
+    StoringFloat64 float64 = 6;
+    bytes bytes = 7;
+    string string = 8;
+  }
+}
+
+message ColumnDescriptor {
+  uint32 id = 1;
+  string name = 2;
+  bool nullable = 3;
+  bool serial = 4;
+  ColumnTypeKind type_kind = 5;
+  ColumnTypeDeclaration type_declaration = 6;
+  ColumnValue default_value = 7;
+}
+
+message IndexDescriptor {
+  uint32 id = 1;
+  string name = 2;
+  bool unique = 3;
+  repeated uint32 column_ids = 4;
+  repeated uint32 storing_column_ids = 5;
+}
diff --git a/src/protos/span.rs b/src/protos/span.rs
index 6d472e3..f00df9c 100644
--- a/src/protos/span.rs
+++ b/src/protos/span.rs
@@ -31,6 +31,16 @@ impl KeyRange {
             Equal
         }
     }
+
+    pub fn resume_from(&self, mut end: Vec<u8>) -> Vec<u8> {
+        match self.end < end {
+            true => {
+                end.clone_from(&self.end);
+                end
+            },
+            false => Vec::default(),
+        }
+    }
 }
 
 impl From<KeyRange> for KeySpan {
@@ -84,6 +94,15 @@ impl KeySpan {
         }
     }
 
+    pub fn append_end(&self, end: &mut Vec<u8>) {
+        if self.end.is_empty() {
+            end.extend(&self.key);
+            end.push(0);
+        } else {
+            end.extend_from_slice(self.end.as_slice());
+        }
+    }
+
     pub fn into_end(mut self) -> Vec<u8> {
         match self.end.is_empty() {
             true => {
@@ -167,7 +186,7 @@ impl<'a> KeySpanRef<'a> {
                 Less => SpanOrdering::LessDisjoint,
                 Equal => SpanOrdering::SubsetLeft,
                 Greater => match self.key.cmp(other.end) {
-                    Less => SpanOrdering::ContainAll,
+                    Less => SpanOrdering::SubsetAll,
                     Equal => SpanOrdering::GreaterContiguous,
                     Greater => SpanOrdering::GreaterDisjoint,
                 },
@@ -238,7 +257,7 @@ mod tests {
         assert_that!(KeySpanRef::new(b"k1", b"").compare(KeySpanRef::new(b"k1", b"k10")))
             .is_equal_to(SpanOrdering::SubsetLeft);
         assert_that!(KeySpanRef::new(b"k1", b"").compare(KeySpanRef::new(b"k0", b"k10")))
-            .is_equal_to(SpanOrdering::ContainAll);
+            .is_equal_to(SpanOrdering::SubsetAll);
         assert_that!(KeySpanRef::new(b"k1", b"").compare(KeySpanRef::new(b"k0", b"k1")))
             .is_equal_to(SpanOrdering::GreaterContiguous);
         assert_that!(KeySpanRef::new(b"k1", b"").compare(KeySpanRef::new(b"k0", b"k01")))
@@ -249,7 +268,7 @@ mod tests {
         assert_that!(KeySpanRef::new(b"k1", b"k10").compare(KeySpanRef::new(b"k1", b"")))
             .is_equal_to(SpanOrdering::ContainLeft);
         assert_that!(KeySpanRef::new(b"k0", b"k10").compare(KeySpanRef::new(b"k1", b"")))
-            .is_equal_to(SpanOrdering::SubsetAll);
+            .is_equal_to(SpanOrdering::ContainAll);
         assert_that!(KeySpanRef::new(b"k0", b"k1").compare(KeySpanRef::new(b"k1", b"")))
             .is_equal_to(SpanOrdering::LessContiguous);
         assert_that!(KeySpanRef::new(b"k0", b"k01").compare(KeySpanRef::new(b"k1", b"")))
diff --git a/src/protos/sql.rs b/src/protos/sql.rs
new file mode 100644
index 0000000..e619616
--- /dev/null
+++ b/src/protos/sql.rs
@@ -0,0 +1,487 @@
+// Copyright 2023 The SeamDB Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::hash::{Hash, Hasher};
+
+use bytes::{Buf, BufMut};
+use prost::Message;
+
+use crate::protos::{
+    ColumnDescriptor,
+    ColumnTypeKind,
+    ColumnValue,
+    DatabaseDescriptor,
+    DescriptorMeta,
+    IndexDescriptor,
+    SchemaDescriptor,
+    StoringFloat32,
+    StoringFloat64,
+    TableDescriptor,
+};
+
+pub trait NamespaceDescriptor {
+    fn kind() -> &'static str;
+
+    fn id(&self) -> u64;
+    fn name(&self) -> &str;
+
+    fn from_meta(meta: DescriptorMeta) -> Self;
+}
+
+impl NamespaceDescriptor for DatabaseDescriptor {
+    fn kind() -> &'static str {
+        "database"
+    }
+
+    fn id(&self) -> u64 {
+        self.id
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn from_meta(meta: DescriptorMeta) -> Self {
+        Self { id: meta.id, name: meta.name, timestamp: meta.timestamp }
+    }
+}
+
+impl NamespaceDescriptor for SchemaDescriptor {
+    fn kind() -> &'static str {
+        "schema"
+    }
+
+    fn id(&self) -> u64 {
+        self.id
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn from_meta(meta: DescriptorMeta) -> Self {
+        Self { id: meta.id, name: meta.name, database_id: meta.parent_id, timestamp: meta.timestamp }
+    }
+}
+
+impl NamespaceDescriptor for TableDescriptor {
+    fn kind() -> &'static str {
+        "table"
+    }
+
+    fn id(&self) -> u64 {
+        self.id
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn from_meta(meta: DescriptorMeta) -> Self {
+        let mut descriptor = TableDescriptor::decode(meta.blob.as_slice()).unwrap();
+        descriptor.id = meta.id;
+        descriptor.name = meta.name;
+        descriptor.database_id = meta.database_id;
+        descriptor.schema_id = meta.parent_id;
+        descriptor.timestamp = meta.timestamp;
+        descriptor
+    }
+}
+
+impl TableDescriptor {
+    pub const POSTGRESQL_DIALECT_PREFIX: &[u8] = b"dus0";
+
+    pub fn serial_key(&self, column: &ColumnDescriptor) -> Vec<u8> {
+        let mut key = Vec::new();
+        key.push(b't');
+        key.put_u64(self.id);
+        key.push(b'c');
+        key.put_u32(column.id);
+        key
+    }
+
+    pub fn primary_index(&self) -> &IndexDescriptor {
+        for index in self.indices.iter() {
+            if index.id == 1 {
+                return index;
+            }
+        }
+        panic!("table {}(id={}) has no primary index", self.name, self.id)
+    }
+
+    pub fn table_prefix(&self) -> Vec<u8> {
+        let mut key = Vec::new();
+        key.push(b't');
+        key.put_u64(self.id);
+        key
+    }
+
+    pub fn index_prefix(&self, index: &IndexDescriptor) -> Vec<u8> {
+        let mut key = self.table_prefix();
+        key.push(b'i');
+        key.put_u32(index.id);
+        key
+    }
+
+    pub fn find_column(&self, name: &str) -> Option<&ColumnDescriptor> {
+        self.columns.iter().find(|column| column.name == name)
+    }
+
+    pub fn column(&self, id: u32) -> Option<&ColumnDescriptor> {
+        self.columns.iter().find(|column| column.id == id)
+    }
+
+    pub fn decode_storing_columns(&self, index: &IndexDescriptor, mut bytes: &[u8]) -> Vec<Option<ColumnValue>> {
+        let mut values = Vec::with_capacity(index.storing_column_ids.len());
+        loop {
+            let id = bytes.get_u32();
+            if id == 0 {
+                break;
+            }
+            let value = ColumnValue::decode(&mut bytes);
+            let i = values.len();
+            if i >= index.storing_column_ids.len() {
+                values.push(value);
+                panic!("index key has more columns than descriptor {index:?}: {values:?}")
+            }
+            if id != index.storing_column_ids[i] {
+                panic!("index descriptor {index:?} does not have column id {id}")
+            }
+            let column = self.column(id).unwrap();
+            match value.as_ref().map(|v| v.type_kind()) {
+                None => {
+                    if !column.nullable {
+                        panic!("column {} expect value, but get null", column.name)
+                    }
+                },
+                Some(type_kind) => {
+                    if type_kind != column.type_kind {
+                        panic!("index {index:?} column {column:?} mismatch column value {value:?}")
+                    }
+                },
+            }
+            values.push(value);
+        }
+        if values.len() != index.storing_column_ids.len() {
+            panic!("index {:?} get mismatching values {:?}", index, values)
+        }
+        values
+    }
+
+    pub fn decode_index_key(&self, index: &IndexDescriptor, mut bytes: &[u8]) -> Vec<ColumnValue> {
+        let prefix = self.index_prefix(index);
+        bytes.advance(prefix.len());
+        let mut values = Vec::with_capacity(index.column_ids.len());
+        loop {
+            let id = bytes.get_u32();
+            if id == 0 {
+                break;
+            }
+            let value = ColumnValue::decode_from_key(&mut bytes);
+            let i = values.len();
+            if i >= index.column_ids.len() {
+                values.push(value);
+                panic!("index key has more columns than descriptor {index:?}: {values:?}")
+            }
+            if id != index.column_ids[i] {
+                panic!("index descriptor {index:?} does not have column id {id}")
+            }
+            let column = self.column(id).unwrap();
+            if value.type_kind() != column.type_kind {
+                panic!("index {index:?} column {column:?} mismatch column value {value:?}")
+            }
+            values.push(value);
+        }
+        if values.len() != index.column_ids.len() {
+            panic!("index {:?} get mismatching keys {:?}", index, values)
+        }
+        values
+    }
+}
+
+impl IndexDescriptor {
+    pub fn is_primary(&self) -> bool {
+        self.id == 1
+    }
+
+    pub fn sole_column(&self) -> Option<u32> {
+        if self.column_ids.len() == 1 {
+            return Some(self.column_ids[0]);
+        }
+        None
+    }
+}
+
+impl ColumnValue {
+    pub fn minimum_of(type_kind: ColumnTypeKind) -> Self {
+        match type_kind {
+            ColumnTypeKind::Boolean => Self::Boolean(false),
+            ColumnTypeKind::Int16 => Self::Int16(i16::MIN.into()),
+            ColumnTypeKind::Int32 => Self::Int32(i32::MIN),
+            ColumnTypeKind::Int64 => Self::Int64(i64::MIN),
+            ColumnTypeKind::Float32 => Self::Float32(f32::MIN.into()),
+            ColumnTypeKind::Float64 => Self::Float64(f64::MIN.into()),
+            ColumnTypeKind::Bytes => Self::Bytes(Default::default()),
+            ColumnTypeKind::String => Self::String(Default::default()),
+        }
+    }
+
+    pub fn type_kind(&self) -> ColumnTypeKind {
+        match self {
+            Self::Boolean(_) => ColumnTypeKind::Boolean,
+            Self::Int16(_) => ColumnTypeKind::Int16,
+            Self::Int32(_) => ColumnTypeKind::Int32,
+            Self::Int64(_) => ColumnTypeKind::Int64,
+            Self::Float32(_) => ColumnTypeKind::Float32,
+            Self::Float64(_) => ColumnTypeKind::Float64,
+            Self::Bytes(_) => ColumnTypeKind::Bytes,
+            Self::String(_) => ColumnTypeKind::String,
+        }
+    }
+
+    pub fn into_u64(self) -> u64 {
+        match self {
+            Self::Int16(i) => i as u64,
+            Self::Int32(i) => i as u64,
+            Self::Int64(i) => i as u64,
+            _ => panic!("expect int, but get {self:?}"),
+        }
+    }
+
+    pub fn into_string(self) -> String {
+        match self {
+            Self::String(s) => s,
+            _ => panic!("expect string, but get {self:?}"),
+        }
+    }
+
+    pub fn into_bytes(self) -> Vec<u8> {
+        match self {
+            Self::Bytes(bytes) => bytes,
+            _ => panic!("expect bytes, but get {self:?}"),
+        }
+    }
+
+    pub fn encode(&self, buf: &mut impl BufMut) {
+        buf.put_i32(self.type_kind() as i32);
+        match &self {
+            Self::Boolean(v) => buf.put_u8(if *v { 1 } else { 0 }),
+            Self::Int16(v) => buf.put_i16(*v as i16),
+            Self::Int32(v) => buf.put_i32(*v),
+            Self::Int64(v) => buf.put_i64(*v),
+            Self::Float32(v) => buf.put_f32(v.value),
+            Self::Float64(v) => buf.put_f64(v.value),
+            Self::Bytes(bytes) => {
+                buf.put_u32(bytes.len() as u32);
+                buf.put(bytes.as_slice())
+            },
+            Self::String(string) => {
+                buf.put_u32(string.len() as u32);
+                buf.put(string.as_bytes())
+            },
+        }
+    }
+
+    pub fn encode_as_key(&self, buf: &mut impl BufMut) {
+        buf.put_i32(self.type_kind() as i32);
+        match self {
+            Self::Boolean(v) => buf.put_u8(if *v { 1 } else { 0 }),
+            Self::Int16(v) => buf.put_i16(*v as i16),
+            Self::Int32(v) => buf.put_i32(*v),
+            Self::Int64(v) => buf.put_i64(*v),
+            Self::Float32(v) => buf.put_f32(v.value),
+            Self::Float64(v) => buf.put_f64(v.value),
+            Self::Bytes(bytes) => Self::encode_escaped(buf, bytes),
+            Self::String(string) => Self::encode_escaped(buf, string.as_bytes()),
+        }
+    }
+
+    pub fn decode_from_key(buf: &mut impl Buf) -> Self {
+        let type_kind = buf.get_i32();
+        match ColumnTypeKind::try_from(type_kind).unwrap() {
+            ColumnTypeKind::Boolean => {
+                let v = buf.get_u8() == 1;
+                ColumnValue::Boolean(v)
+            },
+            ColumnTypeKind::Int16 => {
+                let i = buf.get_i16();
+                ColumnValue::Int16(i.into())
+            },
+            ColumnTypeKind::Int32 => {
+                let i = buf.get_i32();
+                ColumnValue::Int32(i)
+            },
+            ColumnTypeKind::Int64 => {
+                let i = buf.get_i64();
+                ColumnValue::Int64(i)
+            },
+            ColumnTypeKind::Float32 => {
+                let v = buf.get_f32();
+                ColumnValue::Float32(v.into())
+            },
+            ColumnTypeKind::Float64 => {
+                let v = buf.get_f64();
+                ColumnValue::Float64(v.into())
+            },
+            ColumnTypeKind::Bytes => {
+                let bytes = ColumnValue::decode_escaped(buf);
+                ColumnValue::Bytes(bytes)
+            },
+            ColumnTypeKind::String => {
+                let bytes = ColumnValue::decode_escaped(buf);
+                let string = String::from_utf8(bytes).unwrap();
+                ColumnValue::String(string)
+            },
+        }
+    }
+
+    // Cockroach style encoding.
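+    // A zero byte is escaped as 0x00 0xff, and each value ends with the
+    // terminator 0x00 0x01, so encoded keys compare bytewise in the same
+    // order as their source bytes; e.g. [0x01, 0x00, 0x02] encodes to
+    // [0x01, 0x00, 0xff, 0x02, 0x00, 0x01].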
+ // + // Alternative: https://github.com/facebook/mysql-5.6/wiki/MyRocks-record-format#memcomparable-format + fn encode_escaped(buf: &mut impl BufMut, mut bytes: &[u8]) { + loop { + let Some(i) = bytes.iter().position(|b| *b == 0) else { + break; + }; + let (precedings, remainings) = unsafe { (bytes.get_unchecked(0..i), bytes.get_unchecked(i + 1..)) }; + buf.put_slice(precedings); + buf.put_u8(0x00); + buf.put_u8(0xff); + bytes = remainings; + } + buf.put_slice(bytes); + buf.put_u8(0x00); + buf.put_u8(0x01); + } + + fn decode_escaped(buf: &mut dyn Buf) -> Vec<u8> { + let mut bytes = Vec::new(); + let chunk = buf.chunk(); + let mut remainings = chunk; + loop { + let Some(n) = remainings.iter().position(|b| *b == 0) else { + panic!("no terminal in escaped bytes: {remainings:?}") + }; + if n + 1 >= remainings.len() { + panic!("invalid escaped bytes: {remainings:?}") + } + bytes.extend_from_slice(&remainings[0..n]); + let escaped = remainings[n + 1]; + remainings = &remainings[n + 2..]; + match escaped { + 0x01 => break, + 0xff => bytes.push(0x00), + _ => panic!("invalid escaped bytes: {chunk:?}"), + }; + } + buf.advance(chunk.len() - remainings.len()); + bytes + } + + #[allow(clippy::uninit_vec)] + pub fn decode(buf: &mut dyn Buf) -> Option<Self> { + let type_kind = buf.get_i32(); + if type_kind == -1 { + return None; + } + match ColumnTypeKind::try_from(type_kind).unwrap() { + ColumnTypeKind::Boolean => { + let v = buf.get_u8() == 1; + Some(ColumnValue::Boolean(v)) + }, + ColumnTypeKind::Int16 => { + let i = buf.get_i16(); + Some(ColumnValue::Int16(i.into())) + }, + ColumnTypeKind::Int32 => { + let i = buf.get_i32(); + Some(ColumnValue::Int32(i)) + }, + ColumnTypeKind::Int64 => { + let i = buf.get_i64(); + Some(ColumnValue::Int64(i)) + }, + ColumnTypeKind::Float32 => { + let v = buf.get_f32(); + Some(ColumnValue::Float32(v.into())) + }, + ColumnTypeKind::Float64 => { + let v = buf.get_f64(); + Some(ColumnValue::Float64(v.into())) + }, + ColumnTypeKind::Bytes => { + let n = buf.get_u32() as usize; + let mut bytes = Vec::with_capacity(n); + unsafe { + bytes.set_len(n); + } + buf.copy_to_slice(&mut bytes); + Some(ColumnValue::Bytes(bytes)) + }, + ColumnTypeKind::String => { + let n = buf.get_u32() as usize; + let mut bytes = Vec::with_capacity(n); + unsafe { + bytes.set_len(n); + } + buf.copy_to_slice(&mut bytes); + Some(ColumnValue::String(String::from_utf8(bytes).unwrap())) + }, + } + } +} + +impl Eq for StoringFloat32 {} + +impl Hash for StoringFloat32 { + fn hash<H>(&self, state: &mut H) + where + H: Hasher, { + state.write_u32(self.value.to_bits()) + } +} + +impl From<f32> for StoringFloat32 { + fn from(value: f32) -> Self { + Self { value } + } +} + +impl From<StoringFloat32> for f32 { + fn from(value: StoringFloat32) -> Self { + value.value + } +} + +impl Eq for StoringFloat64 {} + +impl Hash for StoringFloat64 { + fn hash<H>(&self, state: &mut H) + where + H: Hasher, { + state.write_u64(self.value.to_bits()) + } +} + +impl From<f64> for StoringFloat64 { + fn from(value: f64) -> Self { + Self { value } + } +} + +impl From<StoringFloat64> for f64 { + fn from(value: StoringFloat64) -> Self { + value.value + } +} diff --git a/src/protos/temporal.rs b/src/protos/temporal.rs index e5b3fd1..c98bc19 100644 --- a/src/protos/temporal.rs +++ b/src/protos/temporal.rs @@ -13,6 +13,7 @@ // limitations under the License. 
use std::cmp::Ordering::{self, *}; +use std::fmt::{Display, Formatter, Result}; use std::time::Duration; use super::*; @@ -50,6 +51,11 @@ impl Temporal { impl Transaction { pub const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(500); + pub fn new(key: Vec<u8>, start_ts: Timestamp) -> Self { + let meta = TxnMeta { id: Uuid::new_random(), key, epoch: 0, start_ts, priority: 0 }; + Transaction { meta, ..Default::default() } + } + pub fn comparer() -> impl Fn(&Transaction, &Transaction) -> Ordering { let comparer = TxnMeta::comparer(); move |a, b| comparer(&a.meta, &b.meta) @@ -199,6 +205,37 @@ impl TxnMeta { } } +impl Display for TxnMeta { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!( + f, + "txn(id={},epoch={},key:{:?},start_ts={},priority={})", + self.id(), + self.epoch(), + self.key(), + self.start_ts, + self.priority + ) + } +} + +impl Display for Transaction { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!( + f, + "txn(id={},epoch={},key:{:?},start_ts={},priority={},status={:?},commit_ts={},heartbeat_ts={})", + self.id(), + self.epoch(), + self.key(), + self.meta.start_ts, + self.meta.priority, + self.status, + self.commit_ts(), + self.heartbeat_ts, + ) + } +} + impl From<TxnMeta> for Transaction { fn from(meta: TxnMeta) -> Self { Self { @@ -220,12 +257,36 @@ } } +impl From<Transaction> for Temporal { + fn from(txn: Transaction) -> Self { + Temporal::Transaction(txn) + } +} + impl From<Timestamp> for Temporal { fn from(t: Timestamp) -> Self { Temporal::Timestamp(t) } } +impl From<Transaction> for Transient { + fn from(txn: Transaction) -> Self { + Self::Transaction(txn.meta) + } +} + +impl From<TxnMeta> for Transient { + fn from(txn: TxnMeta) -> Self { + Self::Transaction(txn) + } +} + +impl From<Timestamp> for Transient { + fn from(t: Timestamp) -> Self { + Self::Timestamp(t) + } +} + pub trait HasTxnMeta { fn meta(&self) -> &TxnMeta; diff --git a/src/sql/client.rs b/src/sql/client.rs new file mode 100644 index 0000000..1eaa14b --- /dev/null +++ b/src/sql/client.rs @@ -0,0 +1,675 @@ +// Copyright 2023 The SeamDB Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
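+ +//! SQL-layer client on top of the key-value store. Roughly: every row is +//! written once per index, with the index's key columns encoded into the +//! storage key and its storing columns into the value; database, schema and +//! table descriptors are themselves rows of the builtin `_databases` table, +//! keyed by (parent_id, name). A hypothetical end-to-end use (setup and +//! error handling abbreviated): +//! +//! ```ignore +//! let client = SqlClient::new(kv_client, KvSemantics::Transactional); +//! let db = client.create_database("app".to_string(), true).await?; +//! let commit_ts = client.insert_rows_once(&table, rows.iter()).await?; +//! ```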
+ +use std::collections::HashMap; +use std::sync::Arc; + +use bytes::{Buf, BufMut}; +use tracing::trace; + +use super::error::SqlError; +use crate::kv::{KvClient, KvError, KvSemantics, LazyInitTimestampedKvClient}; +use crate::protos::{ + ColumnDescriptor, + ColumnTypeKind, + ColumnValue, + DatabaseDescriptor, + DescriptorMeta, + IndexDescriptor, + KeyRange, + NamespaceDescriptor, + SchemaDescriptor, + TableDescriptor, + Timestamp, + TimestampedKeyValue, + Value, +}; +use crate::tablet::TabletClient; +use crate::txn::LazyInitTxn; + +#[derive(Clone)] +pub struct SqlClient { + semantics: KvSemantics, + client: Arc<dyn KvClient>, +} + +impl SqlClient { + pub fn new(client: Arc<dyn KvClient>, semantics: KvSemantics) -> Self { + match (semantics, client.semantics()) { + (KvSemantics::Inconsistent, KvSemantics::Inconsistent) + | (KvSemantics::Snapshot, KvSemantics::Snapshot) + | (KvSemantics::Transactional, KvSemantics::Transactional) => Self { semantics, client }, + (KvSemantics::Inconsistent, _) => Self { semantics, client: Arc::new(client.client().clone()) }, + (KvSemantics::Snapshot, _) => { + Self { semantics, client: Arc::new(LazyInitTimestampedKvClient::new(client.client().clone())) } + }, + (KvSemantics::Transactional, _) => { + Self { semantics, client: Arc::new(LazyInitTxn::new(client.client().clone())) } + }, + } + } + + pub async fn insert_rows_once( + &self, + table: &TableDescriptor, + rows: impl Iterator<Item = &Row>, + ) -> Result<Timestamp, SqlError> { + for row in rows { + trace!("insert row: {row:?}"); + for index in table.indices.iter() { + let (key, value) = row.index_kv(table, index); + self.client.put(&key, Some(Value::Bytes(value)), Some(Timestamp::default())).await?; + } + } + Ok(self.client.commit().await?) + } + + pub async fn delete_rows( + &self, + table: &TableDescriptor, + rows: impl Iterator<Item = &Row>, + expect_ts: Option<Timestamp>, + ) -> Result<(), SqlError> { + for row in rows { + trace!("delete row: {row:?}"); + for index in table.indices.iter() { + let key = row.index_key(table, index); + self.client.put(&key, None, expect_ts).await?; + } + } + Ok(()) + } + + pub async fn get_database(&self, name: &str) -> Result<Option<DatabaseDescriptor>, SqlError> { + self.get_descriptor(0, 0, name.to_string()).await + } + + pub async fn get_schema(&self, database_id: u64, name: &str) -> Result<Option<SchemaDescriptor>, SqlError> { + self.get_descriptor(0, database_id, name.to_string()).await + } + + pub async fn get_table( + &self, + _database_id: u64, + schema_id: u64, + name: &str, + ) -> Result<Option<TableDescriptor>, SqlError> { + self.get_descriptor(0, schema_id, name.to_string()).await + } + + // Table: + // id ==> descriptor + // index: name (unique) + pub async fn create_database(&self, name: String, if_not_exists: bool) -> Result<DatabaseDescriptor, SqlError> { + let table = DatabasesTable::new(); + let mut database_row = Row::default(); + database_row.add_column(Column::with_value(table.parent_id_column().id, ColumnValue::Int64(0))); + database_row.add_column(Column::with_value(table.name_column().id, ColumnValue::String(name.clone()))); + let name_key = database_row.index_key(table.descriptor(), table.naming_index()); + if let Some((timestamp, value)) = self.get(&name_key).await? { + let id = decode_serial_primary_value(table.descriptor(), table.id_column(), &value)?; + return Ok(DatabaseDescriptor { id, name, timestamp }); + } + + let mut schema_row = Row::default(); + schema_row.add_column(Column::with_value(table.name_column().id, ColumnValue::String("public".to_string()))); + let serial_id_key = table.serial_id_key(); + + loop { + if let Some((timestamp, value)) = self.get(&name_key).await? 
{ + if !if_not_exists { + return Err(SqlError::DatabaseAlreadyExists(name)); + } + let id = decode_serial_primary_value(table.descriptor(), table.id_column(), &value)?; + return Ok(DatabaseDescriptor { id, name, timestamp }); + } + self.fill_create_database_rows(&table, &serial_id_key, &mut database_row, &mut schema_row).await?; + match self + .insert_rows_once( + table.descriptor(), + std::iter::once(&database_row).chain(std::iter::once(&schema_row)), + ) + .await + { + Ok(ts) => { + return Ok(DatabaseDescriptor { + id: database_row.serial_column_value(table.id_column()), + name, + timestamp: ts, + }) + }, + Err(SqlError::UniqueIndexAlreadyExists) => match if_not_exists { + true => self.restart()?, + false => return Err(SqlError::DatabaseAlreadyExists(name)), + }, + Err(err) => match err.is_retriable() { + true => continue, + false => return Err(err), + }, + } + } + } + + // FIXME: Background job to drop table content. + // FIXME: DeleteRequest to batch delete without scanning. + pub async fn drop_table(&self, table: &TableDescriptor) -> Result<bool, SqlError> { + let databases = DatabasesTable::new(); + let mut row = Row::default(); + row.add_column(Column::with_value(databases.id_column().id, ColumnValue::Int64(table.id as i64))); + row.add_column(Column::with_value(databases.parent_id_column().id, ColumnValue::Int64(table.schema_id as i64))); + row.add_column(Column::with_value(databases.name_column().id, ColumnValue::String(table.name.clone()))); + self.delete_rows(databases.descriptor(), std::iter::once(&row), Some(table.timestamp)).await?; + let mut start = table.table_prefix(); + let end = { + let mut bytes = start.clone(); + bytes.put_u8(0xff); + bytes + }; + loop { + let (resume_key, rows) = self.scan(&start, &end, 0).await?; + for row in rows { + self.put(&row.key, None, None).await?; + } + if resume_key.is_empty() || resume_key >= end { + break; + } + start = resume_key; + } + self.commit().await?; + Ok(true) + } + + // Table: + // id ==> descriptor + // index: name (unique) + pub async fn create_descriptor( + &self, + kind: &str, + parent_id: u64, + name: String, + blob: Vec<u8>, + if_not_exists: bool, + ) -> Result<(bool, u64, Timestamp), SqlError> { + let table = DatabasesTable::new(); + let mut row = Row::default(); + row.add_column(Column::with_value(table.parent_id_column().id, ColumnValue::Int64(parent_id as i64))); + row.add_column(Column::with_value(table.name_column().id, ColumnValue::String(name.clone()))); + row.add_column(Column::with_value(table.descriptor_column().id, ColumnValue::Bytes(blob))); + let name_key = row.index_key(table.descriptor(), table.naming_index()); + if let Some((timestamp, value)) = self.get(&name_key).await? { + if if_not_exists { + let id = decode_serial_primary_value(table.descriptor(), table.id_column(), &value)?; + return Ok((false, id, timestamp)); + } + return Err(SqlError::identity_already_exists(kind, name)); + } + + loop { + if let Some((timestamp, value)) = self.get(&name_key).await? 
{ + if !if_not_exists { + return Err(SqlError::identity_already_exists(kind, name)); + } + let id = decode_serial_primary_value(table.descriptor(), table.id_column(), &value)?; + return Ok((false, id, timestamp)); + } + self.prefill_row(table.descriptor(), &mut row).await?; + match self.insert_rows_once(table.descriptor(), std::iter::once(&row)).await { + Ok(ts) => return Ok((true, row.serial_column_value(table.id_column()), ts)), + Err(SqlError::UniqueIndexAlreadyExists) => match if_not_exists { + true => self.restart()?, + false => return Err(SqlError::identity_already_exists(kind, name)), + }, + Err(err) => match err.is_retriable() { + true => continue, + false => { + return Err(err); + }, + }, + } + } + } + + pub async fn prefill_row(&self, table: &TableDescriptor, row: &mut Row) -> Result<(), SqlError> { + for column_descriptor in table.columns.iter() { + let Some(column) = row.find_column(column_descriptor.id) else { + if let Some(default_value) = &column_descriptor.default_value { + row.add_column(Column::with_value(column_descriptor.id, default_value.clone())); + continue; + } else if column_descriptor.nullable { + row.add_column(Column::new_null(column_descriptor.id)); + continue; + } else if column_descriptor.serial { + let key = table.serial_key(column_descriptor); + let incremented = self.increment(&key, 1).await?; + let value = match column_descriptor.type_kind { + ColumnTypeKind::Int16 => { + if incremented > i16::MAX as i64 { + return Err(SqlError::unexpected(format!( + "column {} overflow", + column_descriptor.name + ))); + } + ColumnValue::Int16(incremented as i32) + }, + ColumnTypeKind::Int32 => { + if incremented > i32::MAX as i64 { + return Err(SqlError::unexpected(format!( + "column {} overflow", + column_descriptor.name + ))); + } + ColumnValue::Int32(incremented as i32) + }, + ColumnTypeKind::Int64 => ColumnValue::Int64(incremented), + type_kind => { + return Err(SqlError::unexpected(format!( + "column {} has type {:?}, which is not a serial column type", + column_descriptor.name, type_kind + ))); + }, + }; + row.add_column(Column::with_value(column_descriptor.id, value)); + continue; + } + return Err(SqlError::MissingColumn(column_descriptor.name.clone())); + }; + column.check(table, column_descriptor)?; + } + Ok(()) + } + + async fn fill_create_database_rows( + &self, + table: &DatabasesTable, + serial_id_key: &[u8], + database_row: &mut Row, + schema_row: &mut Row, + ) -> Result<(), SqlError> { + let incremented = self.increment(serial_id_key, 2).await?; + let database_id = incremented - 1; + let schema_id = incremented; + database_row.add_column(Column::with_value(table.id_column().id, ColumnValue::Int64(database_id))); + database_row.add_column(Column::new_null(table.descriptor_column().id)); + schema_row.add_column(Column::with_value(table.id_column().id, ColumnValue::Int64(schema_id))); + schema_row.add_column(Column::with_value(table.parent_id_column().id, ColumnValue::Int64(database_id))); + schema_row.add_column(Column::new_null(table.descriptor_column().id)); + Ok(()) + } + + pub async fn get_descriptor<T: FromMeta>( + &self, + database_id: u64, + parent_id: u64, + name: String, + ) -> Result<Option<T>, SqlError> { + let table = DatabasesTable::new(); + let mut row = Row::default(); + row.add_column(Column::with_value(table.parent_id_column().id, ColumnValue::Int64(parent_id as i64))); + row.add_column(Column::with_value(table.name_column().id, ColumnValue::String(name))); + let index = table.naming_index(); + + let key = row.index_key(table.descriptor(), index); + let Some((ts, value)) = 
self.get(&key).await? else { + return Ok(None); + }; + let (id, name, blob) = table.decode_columns(index, TimestampedKeyValue { timestamp: ts, key, value }); + let meta = DescriptorMeta { id, name, database_id, parent_id, timestamp: ts, blob }; + Ok(Some(T::from_meta(meta))) + } + + pub async fn list_descriptors<T: FromMeta>( + &self, + database_id: u64, + parent_id: u64, + ) -> Result<Vec<T>, SqlError> { + let table = DatabasesTable::new(); + let index = table.naming_index(); + let mut row = Row::default(); + row.add_column(Column::with_value(table.parent_id_column().id, ColumnValue::Int64(parent_id as i64))); + let KeyRange { mut start, end } = row.index_range(table.descriptor(), index); + let mut descriptors = Vec::new(); + loop { + let (resume_key, rows) = self.scan(&start, &end, 0).await?; + for row in rows { + let timestamp = row.timestamp; + let (id, name, blob) = table.decode_columns(index, row); + let meta = DescriptorMeta { id, name, database_id, parent_id, timestamp, blob }; + descriptors.push(T::from_meta(meta)); + } + if resume_key.is_empty() || resume_key >= end { + break; + } + start = resume_key; + } + Ok(descriptors) + } +} + +#[async_trait::async_trait] +impl KvClient for SqlClient { + fn client(&self) -> &TabletClient { + self.client.client() + } + + fn semantics(&self) -> KvSemantics { + self.semantics + } + + async fn get(&self, key: &[u8]) -> Result<Option<(Timestamp, Value)>, KvError> { + self.client.get(key).await + } + + async fn scan(&self, start: &[u8], end: &[u8], limit: u32) -> Result<(Vec<u8>, Vec<TimestampedKeyValue>), KvError> { + self.client.scan(start, end, limit).await + } + + async fn put(&self, key: &[u8], value: Option<Value>, expect_ts: Option<Timestamp>) -> Result<Timestamp, KvError> { + self.client.put(key, value, expect_ts).await + } + + async fn increment(&self, key: &[u8], increment: i64) -> Result<i64, KvError> { + self.client.increment(key, increment).await + } + + async fn commit(&self) -> Result<Timestamp, KvError> { + self.client.commit().await + } + + async fn abort(&self) -> Result<(), KvError> { + self.client.abort().await + } + + fn restart(&self) -> Result<(), KvError> { + self.client.restart() + } +} + +// tenant (k,v) +// meta +// database +// table +// +// databases +// rows: +// id => name +// indices: +// name => id +// tables +// +// table/index/primary_keys ==> (version, key1, key2) + +#[derive(Debug, Clone)] +pub struct Column { + id: u32, + value: Option<ColumnValue>, +} + +impl Column { + pub fn new(id: u32, value: Option<ColumnValue>) -> Self { + Self { id, value } + } + + pub fn with_value(id: u32, value: ColumnValue) -> Self { + Self { id, value: Some(value) } + } + + pub fn new_null(id: u32) -> Self { + Self { id, value: None } + } + + pub fn check(&self, table: &TableDescriptor, column: &ColumnDescriptor) -> Result<(), SqlError> { + let Some(value) = &self.value else { + if !column.nullable { + return Err(SqlError::NotNullableColumn { table: table.name.clone(), column: column.name.clone() }); + } + return Ok(()); + }; + let value_type = value.type_kind(); + if value_type != column.type_kind { + return Err(SqlError::MismatchColumnType { + table: table.name.clone(), + column: column.name.clone(), + expect: column.type_kind, + actual: value_type, + }); + } + Ok(()) + } +} + +#[derive(Debug, Clone, Default)] +pub struct Row { + columns: HashMap<u32, Column>, +} + +impl Row { + pub fn serial_column_value(&self, descriptor: &ColumnDescriptor) -> u64 { + let Some(column) = self.columns.get(&descriptor.id) else { + panic!("row does not have column {}", descriptor.name) + }; + match &column.value { + Some(ColumnValue::Int16(i)) => *i as u64, + Some(ColumnValue::Int32(i)) => *i as u64, + 
Some(ColumnValue::Int64(i)) => *i as u64, + _ => panic!("row does not have serial column {}, value {:?}", descriptor.name, column.value), + } + } + + pub fn find_column(&self, id: u32) -> Option<&Column> { + self.columns.get(&id) + } + + pub fn take_column(&mut self, id: u32) -> Option<ColumnValue> { + self.columns.remove(&id).and_then(|c| c.value) + } + + pub fn index_range(&self, table: &TableDescriptor, index: &IndexDescriptor) -> KeyRange { + let mut key = table.index_prefix(index); + for id in index.column_ids.iter().copied() { + match self.find_column(id) { + None => { + let mut end = key.clone(); + key.put_i32(id as i32); + end.put_i32((id + 1) as i32); + return KeyRange::new(key, end); + }, + Some(column) => { + key.put_i32(id as i32); + column.value.as_ref().unwrap().encode_as_key(&mut key); + }, + }; + } + let mut end = key.clone(); + key.put_u32(0); + end.put_u32(1); + KeyRange::new(key, end) + } + + pub fn index_key(&self, table: &TableDescriptor, index: &IndexDescriptor) -> Vec<u8> { + let mut key = table.index_prefix(index); + for id in index.column_ids.iter().copied() { + let column = self.find_column(id).unwrap(); + key.put_u32(column.id); + match column.value.as_ref() { + None => key.put_i32(-1), + Some(value) => value.encode_as_key(&mut key), + } + } + key.put_u32(0); + key + } + + pub fn index_kv(&self, table: &TableDescriptor, index: &IndexDescriptor) -> (Vec<u8>, Vec<u8>) { + let key = self.index_key(table, index); + let mut value_bytes = Vec::new(); + for id in index.storing_column_ids.iter().copied() { + let column = self.find_column(id).unwrap_or_else(|| panic!("no column {id} for index {}", index.id)); + value_bytes.put_u32(column.id); + match column.value.as_ref() { + None => value_bytes.put_i32(-1), + Some(value) => value.encode(&mut value_bytes), + } + } + value_bytes.put_u32(0); + (key, value_bytes) + } + + pub fn add_column(&mut self, column: Column) { + self.columns.insert(column.id, column); + } +} + +pub struct DatabasesTable { + descriptor: TableDescriptor, +} + +impl DatabasesTable { + pub fn new() -> Self { + Self { descriptor: databases_table_descriptor() } + } + + pub fn descriptor(&self) -> &TableDescriptor { + &self.descriptor + } + + pub fn serial_id_key(&self) -> Vec<u8> { + let id_column = self.id_column(); + self.descriptor.serial_key(id_column) + } + + pub fn id_column(&self) -> &ColumnDescriptor { + &self.descriptor.columns[0] + } + + pub fn parent_id_column(&self) -> &ColumnDescriptor { + &self.descriptor.columns[1] + } + + pub fn name_column(&self) -> &ColumnDescriptor { + &self.descriptor.columns[2] + } + + pub fn descriptor_column(&self) -> &ColumnDescriptor { + &self.descriptor.columns[3] + } + + pub fn naming_index(&self) -> &IndexDescriptor { + &self.descriptor.indices[1] + } + + pub fn decode_columns(&self, index: &IndexDescriptor, kv: TimestampedKeyValue) -> (u64, String, Vec<u8>) { + let mut row = Row::default(); + for (i, value) in self.descriptor.decode_index_key(index, &kv.key).into_iter().enumerate() { + row.add_column(Column::new(index.column_ids[i], Some(value))); + } + let index_values = + self.descriptor.decode_storing_columns(index, kv.value.read_bytes(&kv.key, "sql key value").unwrap()); + for (i, value) in index_values.into_iter().enumerate() { + row.add_column(Column::new(index.storing_column_ids[i], value)); + } + let id = row.take_column(self.id_column().id).unwrap().into_u64(); + let name = row.take_column(self.name_column().id).unwrap().into_string(); + let bytes = row.take_column(self.descriptor_column().id).map(|v| 
v.into_bytes()).unwrap_or(Vec::default()); + (id, name, bytes) + } +} + +fn databases_table_descriptor() -> TableDescriptor { + TableDescriptor { + id: 0, + database_id: 1, + schema_id: 2, + timestamp: Timestamp::ZERO, + name: "_databases".to_string(), + columns: vec![ + ColumnDescriptor { + id: 1, + name: "id".to_string(), + nullable: false, + serial: true, + type_kind: ColumnTypeKind::Int64, + type_declaration: None, + default_value: None, + }, + ColumnDescriptor { + id: 2, + name: "parent_id".to_string(), + nullable: false, + serial: false, + type_kind: ColumnTypeKind::Int64, + type_declaration: None, + default_value: None, + }, + ColumnDescriptor { + id: 3, + name: "name".to_string(), + nullable: false, + serial: false, + type_kind: ColumnTypeKind::String, + type_declaration: None, + default_value: None, + }, + ColumnDescriptor { + id: 4, + name: "descriptor".to_string(), + nullable: true, + serial: false, + type_kind: ColumnTypeKind::Bytes, + type_declaration: None, + default_value: None, + }, + ], + last_column_id: 4, + indices: vec![ + IndexDescriptor { + id: 1, + name: "primary".to_string(), + unique: true, + column_ids: vec![1], + storing_column_ids: vec![2, 3, 4], + }, + IndexDescriptor { + id: 2, + name: "naming".to_string(), + unique: true, + column_ids: vec![2, 3], + storing_column_ids: vec![1, 4], + }, + ], + last_index_id: 2, + } +} + +// FIXME: +fn decode_serial_primary_value( + table: &TableDescriptor, + column: &ColumnDescriptor, + value: &Value, +) -> Result<u64, SqlError> { + let Value::Bytes(bytes) = value else { + panic!("column {}.{} expects bytes, got {:?}", table.name, column.name, value) + }; + let mut bytes = bytes.as_slice(); + let column_id = bytes.get_u32(); + if column_id != column.id { + panic!("expect column {}.{}(id:{}), but got {}", table.name, column.name, column.id, column_id) + } + match ColumnValue::decode(&mut bytes) { + None => panic!("no column value"), + Some(ColumnValue::Int16(i)) => Ok(i as u64), + Some(ColumnValue::Int32(i)) => Ok(i as u64), + Some(ColumnValue::Int64(i)) => Ok(i as u64), + Some(value) => panic!("got value with type {:?}", value.type_kind()), + } +} diff --git a/src/sql/descriptor.rs b/src/sql/descriptor.rs new file mode 100644 index 0000000..645e088 --- /dev/null +++ b/src/sql/descriptor.rs @@ -0,0 +1,80 @@ +// Copyright 2023 The SeamDB Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
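+ +//! Resolution of fully qualified table references to table providers. +//! `TableDescriptorFetcher` walks catalog -> schema -> table through the +//! descriptor rows and caches the databases and tables it has already seen, +//! so a single planning pass does not refetch the same descriptor twice.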
+ +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion::catalog::TableProvider; +use datafusion::catalog_common::ResolvedTableReference; +use datafusion::sql::TableReference; + +use super::client::SqlClient; +use super::error::SqlError; +use crate::protos::{DatabaseDescriptor, SchemaDescriptor}; +use crate::sql::plan::SqlTable; + +struct DatabaseMeta { + #[allow(unused)] + descriptor: DatabaseDescriptor, + schemas: Vec<SchemaDescriptor>, +} + +pub struct TableDescriptorFetcher<'a> { + client: &'a SqlClient, + databases: HashMap<String, DatabaseMeta>, + tables: HashMap<TableReference, Arc<dyn TableProvider>>, +} + +impl<'a> TableDescriptorFetcher<'a> { + pub fn new(client: &'a SqlClient) -> Self { + Self { client, databases: Default::default(), tables: Default::default() } + } + + pub async fn get_table( + &mut self, + table_ref: &ResolvedTableReference, + ) -> Result<Option<Arc<dyn TableProvider>>, SqlError> { + if let Some(table) = self.tables.get(&TableReference::from(table_ref.clone())) { + return Ok(Some(table.clone())); + }; + let database = match self.databases.get(table_ref.catalog.as_ref()) { + Some(database) => database, + None => { + let Some(database_descriptor) = + self.client.get_descriptor::<DatabaseDescriptor>(0, 0, table_ref.catalog.to_string()).await? + else { + return Ok(None); + }; + let schema_descriptors = + self.client.list_descriptors(database_descriptor.id, database_descriptor.id).await?; + self.databases.insert(table_ref.catalog.to_string(), DatabaseMeta { + descriptor: database_descriptor, + schemas: schema_descriptors, + }); + self.databases.get(table_ref.catalog.as_ref()).unwrap() + }, + }; + let Some(schema) = database.schemas.iter().find(|d| d.name == table_ref.schema.as_ref()) else { + return Ok(None); + }; + let Some(table_descriptor) = + self.client.get_descriptor(schema.database_id, schema.id, table_ref.table.to_string()).await? + else { + return Ok(None); + }; + let table: Arc<dyn TableProvider> = Arc::new(SqlTable::new(table_descriptor)); + self.tables.insert(table_ref.clone().into(), table.clone()); + Ok(Some(table)) + } +} diff --git a/src/sql/error.rs b/src/sql/error.rs new file mode 100644 index 0000000..ce794ae --- /dev/null +++ b/src/sql/error.rs @@ -0,0 +1,107 @@ +// Copyright 2023 The SeamDB Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
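+ +//! `SqlError` is the error type shared by SQL planning and execution. It +//! converts from the storage-level errors (`KvError`, `TabletClientError`) +//! and both from and into `DataFusionError`, so it can cross DataFusion +//! planning and execution boundaries.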
+ +use std::borrow::Cow; + +use datafusion::common::DataFusionError; +use datafusion::sql::sqlparser::parser::ParserError; +use thiserror::Error; + +use crate::kv::KvError; +use crate::protos::ColumnTypeKind; +use crate::tablet::TabletClientError; + +#[derive(Debug, Error)] +pub enum SqlError { + #[error("no statement")] + NoStatement, + #[error("multiple statements")] + MultipleStatements, + #[error("missing column {0}")] + MissingColumn(String), + #[error("database {0} does not exist")] + DatabaseNotExists(String), + #[error("database {0} already exists")] + DatabaseAlreadyExists(String), + #[error("schema {0} does not exist")] + SchemaNotExists(String), + #[error("schema {0} already exists")] + SchemaAlreadyExists(String), + #[error("table {0} already exists")] + TableAlreadyExists(String), + #[error("{0}")] + TabletClientError(#[from] TabletClientError), + #[error("{0}")] + KvError(#[from] KvError), + #[error("{table}.{column} is not nullable")] + NotNullableColumn { table: String, column: String }, + #[error("{table}.{column} expects type {expect:?} but got {actual:?}")] + MismatchColumnType { table: String, column: String, expect: ColumnTypeKind, actual: ColumnTypeKind }, + #[error("index already exists")] + UniqueIndexAlreadyExists, + #[error("{0}")] + ExecutorError(#[from] DataFusionError), + #[error("{0} unimplemented")] + Unimplemented(Cow<'static, str>), + #[error("invalid sql: {0}")] + Invalid(Cow<'static, str>), + #[error("{0} unsupported")] + Unsupported(Cow<'static, str>), + #[error("unexpected error: {0}")] + Unexpected(Cow<'static, str>), +} + +impl SqlError { + pub fn unexpected(msg: impl Into<Cow<'static, str>>) -> Self { + Self::Unexpected(msg.into()) + } + + pub fn unsupported(feature: impl Into<Cow<'static, str>>) -> Self { + Self::Unsupported(feature.into()) + } + + pub fn unimplemented(msg: impl Into<Cow<'static, str>>) -> Self { + Self::Unimplemented(msg.into()) + } + + pub fn invalid(msg: impl Into<Cow<'static, str>>) -> Self { + Self::Invalid(msg.into()) + } + + pub fn is_retriable(&self) -> bool { + // FIXME + false + } + + pub fn identity_already_exists(kind: &str, name: String) -> Self { + match kind { + "table" => Self::TableAlreadyExists(name), + "schema" => Self::SchemaAlreadyExists(name), + "database" => Self::DatabaseAlreadyExists(name), + _ => Self::unexpected(format!("{kind} {name} already exists")), + } + } +} + +impl From<ParserError> for SqlError { + fn from(err: ParserError) -> SqlError { + SqlError::from(DataFusionError::from(err)) + } +} + +impl From<SqlError> for DataFusionError { + fn from(err: SqlError) -> Self { + DataFusionError::External(Box::new(err)) + } +} diff --git a/src/sql/mod.rs b/src/sql/mod.rs new file mode 100644 index 0000000..f46d3ca --- /dev/null +++ b/src/sql/mod.rs @@ -0,0 +1,253 @@ +// Copyright 2023 The SeamDB Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
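+ +//! PostgreSQL-flavored SQL execution on top of the tablet layer. In sketch +//! form: `PostgresPlanner` parses and plans a statement, +//! `create_physical_plan` lowers DDL and DML nodes to the custom execution +//! plans in `plan`, and `execute_plan` runs the result with a transactional +//! `SqlClient` registered as a session extension.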
+ +mod client; +mod descriptor; +mod error; +mod plan; +pub mod postgres; +mod shared; +mod traits; + +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion::catalog::TableProvider; +use datafusion::catalog_common::ResolvedTableReference; +use datafusion::common::plan_datafusion_err; +use datafusion::execution::config::SessionConfig; +use datafusion::execution::session_state::{SessionState, SessionStateBuilder}; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::logical_expr::logical_plan::dml::InsertOp; +use datafusion::logical_expr::{CreateCatalog, DdlStatement, DmlStatement, DropTable, LogicalPlan, WriteOp}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; +use datafusion::sql::TableReference; + +use self::client::SqlClient; +use self::descriptor::TableDescriptorFetcher; +pub use self::error::SqlError; +use self::plan::{CreateDatabaseExec, DropTableExec}; +use self::postgres::PostgresPlanner; +use self::traits::*; +use crate::kv::{KvClient, KvSemantics}; + +pub struct PostgreSqlExecutor { + client: Arc<dyn KvClient>, + state: SessionState, + planner: DefaultPhysicalPlanner, +} + +#[async_trait::async_trait] +impl PlannerContext for PostgreSqlExecutor { + fn state(&self) -> &SessionState { + &self.state + } + + async fn fetch_table_references( + &self, + table_references: Vec<ResolvedTableReference>, + ) -> Result<HashMap<TableReference, Arc<dyn TableProvider>>, SqlError> { + let mut tables = HashMap::with_capacity(table_references.len()); + let client = SqlClient::new(self.client.clone(), KvSemantics::Snapshot); + let mut fetcher = TableDescriptorFetcher::new(&client); + for table_ref in table_references { + let Some(table) = fetcher.get_table(&table_ref).await? else { + continue; + }; + tables.insert(TableReference::from(table_ref), table); + } + Ok(tables) + } +} + +impl PostgreSqlExecutor { + pub fn new(client: Arc<dyn KvClient>, database: String) -> Self { + let mut config = SessionConfig::default(); + config.options_mut().catalog.create_default_catalog_and_schema = false; + config.options_mut().catalog.default_catalog = database; + config.options_mut().catalog.information_schema = true; + let state = SessionStateBuilder::new().with_default_features().with_config(config).build(); + let planner = DefaultPhysicalPlanner::with_extension_planners(plan::get_extension_planners()); + Self { client, state, planner } + } + + pub async fn execute_sql(&self, sql: &str) -> Result<SendableRecordBatchStream, SqlError> { + let planner = PostgresPlanner::new(self); + let (plan, tables) = planner.plan(sql).await?; + self.execute_plan(plan, tables).await + } + + async fn create_physical_plan( + &self, + plan: LogicalPlan, + tables: &HashMap<TableReference, Arc<dyn TableProvider>>, + ) -> Result<Arc<dyn ExecutionPlan>, SqlError> { + let execution_plan: Arc<dyn ExecutionPlan> = match plan { + LogicalPlan::Ddl(DdlStatement::CreateExternalTable(_)) => { + return Err(SqlError::unsupported("create external table")) + }, + LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(_)) => { + return Err(SqlError::unexpected("ddl CreateTable")) + }, + LogicalPlan::Ddl(DdlStatement::CreateView(_)) => return Err(SqlError::unimplemented("create view")), + LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema(_)) => { + return Err(SqlError::unimplemented("create schema")) + }, + LogicalPlan::Ddl(DdlStatement::CreateCatalog(CreateCatalog { catalog_name, if_not_exists, .. })) => { + Arc::new(CreateDatabaseExec::new(catalog_name, if_not_exists)) + }, + LogicalPlan::Ddl(DdlStatement::CreateIndex(_)) => return Err(SqlError::unimplemented("create index")), + LogicalPlan::Ddl(DdlStatement::DropTable(DropTable { name, if_exists, .. 
})) => { + let name = TableReference::from(self.resolve_table_reference(name)); + Arc::new(DropTableExec::new(name, if_exists)) + }, + LogicalPlan::Ddl(DdlStatement::DropView(_)) => return Err(SqlError::unimplemented("drop view")), + LogicalPlan::Ddl(DdlStatement::DropCatalogSchema(_)) => return Err(SqlError::unimplemented("drop schema")), + LogicalPlan::Ddl(DdlStatement::CreateFunction(_)) => { + return Err(SqlError::unimplemented("create function")) + }, + LogicalPlan::Ddl(DdlStatement::DropFunction(_)) => return Err(SqlError::unimplemented("drop function")), + LogicalPlan::Dml(DmlStatement { table_name, op: WriteOp::Insert(InsertOp::Append), input, .. }) => { + let table_name = TableReference::from(self.resolve_table_reference(table_name)); + let table = tables.get(&table_name).ok_or_else(|| plan_datafusion_err!("no table {table_name}"))?; + let input_exec = Box::pin(self.create_physical_plan(Arc::unwrap_or_clone(input), tables)).await?; + table.insert_into(&self.state, input_exec, InsertOp::Append).await? + }, + _ => { + let optimized_plan = self.state.optimize(&plan)?; + self.planner.create_physical_plan(&optimized_plan, &self.state).await? + }, + }; + + Ok(execution_plan) + } + + async fn execute_plan( + &self, + plan: LogicalPlan, + tables: HashMap<TableReference, Arc<dyn TableProvider>>, + ) -> Result<SendableRecordBatchStream, SqlError> { + let execution_plan = self.create_physical_plan(plan, &tables).await?; + let mut state = self.state.clone(); + let client = SqlClient::new(self.client.clone(), KvSemantics::Transactional); + state.config_mut().set_extension(Arc::new(client.clone())); + let stream = execution_plan.execute(0, state.task_ctx())?; + Ok(stream) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::time::Duration; + + use datafusion::common::arrow::array::{Array, Int32Array, Int64Array, StringArray, UInt64Array}; + use futures::prelude::stream::StreamExt; + use tokio::net::TcpListener; + + use super::PostgreSqlExecutor; + use crate::cluster::tests::etcd_container; + use crate::cluster::{ClusterEnv, EtcdClusterMetaDaemon, EtcdNodeRegistry, NodeId}; + use crate::endpoint::{Endpoint, Params}; + use crate::log::{LogManager, MemoryLogFactory}; + use crate::protos::TableDescriptor; + use crate::tablet::{TabletClient, TabletNode}; + + #[test_log::test(tokio::test)] + #[tracing_test::traced_test] + async fn query() { + let etcd = etcd_container(); + let cluster_uri = etcd.uri().with_path("/team1/seamdb1").unwrap(); + + let node_id = NodeId::new_random(); + let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap(); + let address = format!("http://{}", listener.local_addr().unwrap()); + let endpoint = Endpoint::try_from(address.as_str()).unwrap(); + let (nodes, lease) = + EtcdNodeRegistry::join(cluster_uri.clone(), node_id.clone(), Some(endpoint.to_owned())).await.unwrap(); + let log_manager = + LogManager::new(MemoryLogFactory::new(), &MemoryLogFactory::ENDPOINT, &Params::default()).await.unwrap(); + let cluster_env = ClusterEnv::new(log_manager.into(), nodes).with_replicas(1); + let mut cluster_meta_handle = + EtcdClusterMetaDaemon::start("seamdb1", cluster_uri.clone(), cluster_env.clone()).await.unwrap(); + let descriptor_watcher = cluster_meta_handle.watch_descriptor(None).await.unwrap(); + let deployment_watcher = cluster_meta_handle.watch_deployment(None).await.unwrap(); + let cluster_env = cluster_env.with_descriptor(descriptor_watcher).with_deployment(deployment_watcher.monitor()); + let _node = TabletNode::start(node_id, listener, lease, cluster_env.clone()); + let client = 
TabletClient::new(cluster_env).scope(TableDescriptor::POSTGRESQL_DIALECT_PREFIX); + tokio::time::sleep(Duration::from_secs(20)).await; + + let executor = PostgreSqlExecutor::new(Arc::new(client), "test1".to_string()); + let mut stream = executor.execute_sql("CREATE DATABASE test1").await.unwrap(); + while let Some(_record) = stream.next().await {} + + let mut stream = executor + .execute_sql( + r#"CREATE TABLE table1 ( + id serial PRIMARY KEY, + count bigint, + price real, + description text + );"#, + ) + .await + .unwrap(); + while let Some(_record) = stream.next().await {} + + let mut stream = executor + .execute_sql( + r#"INSERT INTO table1 + (count, price, description) + VALUES + (4, 15.6, NULL), + (3, 7.8, 'NNNNNN'), + (8, 3.4, 'a'), + (8, 2.9, 'b'); + "#, + ) + .await + .unwrap(); + let record = stream.next().await.unwrap().unwrap(); + let column = record.column(0); + let array = column.as_any().downcast_ref::<UInt64Array>().unwrap(); + assert_eq!(array.value(0), 4u64); + + let mut stream = executor + .execute_sql("select id, count, description from table1 ORDER BY count DESC, id ASC;") + .await + .unwrap(); + + let record = stream.next().await.unwrap().unwrap(); + let column_id = record.column(0).as_any().downcast_ref::<Int32Array>().unwrap(); + let column_count = record.column(1).as_any().downcast_ref::<Int64Array>().unwrap(); + let column_description = record.column(2).as_any().downcast_ref::<StringArray>().unwrap(); + + assert_eq!(column_id.value(0), 3); + assert_eq!(column_count.value(0), 8); + assert_eq!(column_description.value(0), "a"); + + assert_eq!(column_id.value(1), 4); + assert_eq!(column_count.value(1), 8); + assert_eq!(column_description.value(1), "b"); + + assert_eq!(column_id.value(2), 1); + assert_eq!(column_count.value(2), 4); + assert!(column_description.is_null(2)); + + assert_eq!(column_id.value(3), 2); + assert_eq!(column_count.value(3), 3); + assert_eq!(column_description.value(3), "NNNNNN"); + } +} diff --git a/src/sql/plan/catalog.rs b/src/sql/plan/catalog.rs new file mode 100644 index 0000000..2acb3e5 --- /dev/null +++ b/src/sql/plan/catalog.rs @@ -0,0 +1,98 @@ +// Copyright 2023 The SeamDB Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
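+ +//! Physical plan node for CREATE DATABASE: `CreateDatabaseExec` pulls the +//! `SqlClient` session extension and writes the new database (plus its +//! default `public` schema) as descriptor rows.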
+ +use std::any::Any; +use std::fmt; +use std::sync::Arc; + +use datafusion::common::arrow::datatypes::Schema; +use datafusion::common::{DataFusionError, Result as DFResult}; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::stream::RecordBatchReceiverStreamBuilder; +use datafusion::physical_plan::{ + DisplayAs, + DisplayFormatType, + ExecutionMode, + ExecutionPlan, + Partitioning, + PlanProperties, +}; + +use crate::sql::client::SqlClient; + +#[derive(Debug)] +pub struct CreateDatabaseExec { + name: String, + if_not_exists: bool, + schema: Arc<Schema>, + properties: PlanProperties, +} + +impl CreateDatabaseExec { + pub fn new(name: String, if_not_exists: bool) -> Self { + let schema = Arc::new(Schema::empty()); + let properties = PlanProperties::new( + EquivalenceProperties::new(schema.clone()), + Partitioning::UnknownPartitioning(1), + ExecutionMode::Bounded, + ); + Self { name, if_not_exists, schema, properties } + } +} + +impl DisplayAs for CreateDatabaseExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "CreateDatabase({})", self.name) + } +} + +impl ExecutionPlan for CreateDatabaseExec { + fn name(&self) -> &str { + "CreateDatabaseExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn execute(&self, _partition: usize, context: Arc<TaskContext>) -> DFResult<SendableRecordBatchStream> { + let client = context + .session_config() + .get_extension::<SqlClient>() + .ok_or_else(|| DataFusionError::Execution("no sql client".to_string()))?; + let mut builder = RecordBatchReceiverStreamBuilder::new(self.schema.clone(), 1); + let sender = builder.tx(); + let name = self.name.clone(); + let if_not_exists = self.if_not_exists; + builder.spawn(async move { + client.create_database(name, if_not_exists).await?; + drop(sender); + Ok(()) + }); + Ok(builder.build()) + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { + vec![] + } + + fn with_new_children(self: Arc<Self>, _children: Vec<Arc<dyn ExecutionPlan>>) -> DFResult<Arc<dyn ExecutionPlan>> { + Ok(self) + } +} diff --git a/src/sql/plan/create_table.rs b/src/sql/plan/create_table.rs new file mode 100644 index 0000000..bd78ee5 --- /dev/null +++ b/src/sql/plan/create_table.rs @@ -0,0 +1,219 @@ +// Copyright 2023 The SeamDB Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
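+ +//! CREATE TABLE support: `CreateTablePlan` is the user-defined logical node +//! produced by the planner, and `CreateTableExec` is the physical node that +//! encodes the `TableDescriptor` and stores it through `SqlClient`, reporting +//! "created" or "already exists" as a one-row record batch.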
+ +use std::any::Any; +use std::fmt::{self, Formatter}; +use std::sync::Arc; + +use datafusion::catalog_common::ResolvedTableReference; +use datafusion::common::arrow::array::array::StringArray; +use datafusion::common::arrow::datatypes::{self, Field, Schema}; +use datafusion::common::arrow::record_batch::RecordBatch; +use datafusion::common::{DFSchema, DataFusionError}; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::logical_expr::{Expr, Extension, LogicalPlan, UserDefinedLogicalNodeCore}; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::stream::RecordBatchReceiverStreamBuilder; +use datafusion::physical_plan::{ + DisplayAs, + DisplayFormatType, + ExecutionMode, + ExecutionPlan, + Partitioning, + PlanProperties, +}; +use datafusion::sql::TableReference; +use ignore_result::Ignore; +use lazy_static::lazy_static; +use prost::Message; +use tokio::sync::mpsc; + +use crate::sql::client::SqlClient; +use crate::sql::SqlError; + +lazy_static! { + pub static ref CREATE_TABLE_SCHEMA: Arc<Schema> = + Arc::new(Schema::new([Arc::new(Field::new("result", datatypes::DataType::Utf8, false))])); +} + +use crate::protos::TableDescriptor; +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive_where::derive_where(PartialOrd)] +pub struct CreateTablePlan { + #[derive_where(skip(EqHashOrd))] + pub schema: Arc<DFSchema>, + pub table_ref: TableReference, + pub table_descriptor: TableDescriptor, + pub if_not_exists: bool, +} + +impl CreateTablePlan { + pub fn new(reference: TableReference, descriptor: TableDescriptor, if_not_exists: bool) -> Self { + let schema = Arc::new(DFSchema::try_from(CREATE_TABLE_SCHEMA.clone()).unwrap()); + Self { table_ref: reference, table_descriptor: descriptor, if_not_exists, schema } + } +} + +impl From<CreateTablePlan> for LogicalPlan { + fn from(plan: CreateTablePlan) -> LogicalPlan { + LogicalPlan::Extension(Extension { node: Arc::new(plan) }) + } +} + +impl UserDefinedLogicalNodeCore for CreateTablePlan { + fn name(&self) -> &str { + "CreateTablePlan" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![] + } + + fn schema(&self) -> &Arc<DFSchema> { + &self.schema + } + + fn expressions(&self) -> Vec<Expr> { + vec![] + } + + fn fmt_for_explain(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "CreateTable({})", self.table_ref) + } + + fn with_exprs_and_inputs(&self, exprs: Vec<Expr>, inputs: Vec<LogicalPlan>) -> Result<Self, DataFusionError> { + if !exprs.is_empty() || !inputs.is_empty() { + return Err(DataFusionError::Internal("CreateTablePlan has no inputs or expressions".to_string())); + } + Ok(self.clone()) + } +} + +#[derive(Debug)] +pub struct CreateTableExec { + table_ref: ResolvedTableReference, + descriptor: TableDescriptor, + if_not_exists: bool, + properties: PlanProperties, +} + +impl CreateTableExec { + pub fn new(table_ref: ResolvedTableReference, descriptor: TableDescriptor, if_not_exists: bool) -> Self { + let properties = PlanProperties::new( + EquivalenceProperties::new(Schema::empty().into()), + Partitioning::UnknownPartitioning(1), + ExecutionMode::Bounded, + ); + Self { table_ref, descriptor, if_not_exists, properties } + } +} + +impl DisplayAs for CreateTableExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { + write!(f, "CreateTableExec(table: {})", self.table_ref) + } +} + +impl ExecutionPlan for CreateTableExec { + fn name(&self) -> &str { + "CreateTableExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn execute( + &self, + _partition: usize, + context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream, DataFusionError> { + let client = context 
.session_config() + .get_extension::<SqlClient>() + .ok_or_else(|| DataFusionError::Execution("no sql kv client".to_string()))?; + let mut builder = RecordBatchReceiverStreamBuilder::new(CREATE_TABLE_SCHEMA.clone(), 1); + let sender = builder.tx(); + builder.spawn(create_table( + client, + self.table_ref.clone(), + self.descriptor.clone(), + self.if_not_exists, + sender, + )); + Ok(builder.build()) + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { + vec![] + } + + fn with_new_children( + self: Arc<Self>, + _children: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> { + Ok(self) + } +} + +async fn create_table_internally( + client: Arc<SqlClient>, + reference: ResolvedTableReference, + descriptor: TableDescriptor, + if_not_exists: bool, +) -> Result<RecordBatch, DataFusionError> { + let database = match client.get_database(&reference.catalog).await? { + None => return Err(SqlError::DatabaseNotExists(reference.catalog.to_string()).into()), + Some(database) => database, + }; + let schema = match client.get_schema(database.id, &reference.schema).await? { + None => return Err(SqlError::SchemaNotExists(format!("{}.{}", reference.catalog, reference.schema)).into()), + Some(schema) => schema, + }; + if client.get_table(database.id, schema.id, &reference.table).await?.is_some() { + if if_not_exists { + let record = RecordBatch::try_new(CREATE_TABLE_SCHEMA.clone(), vec![Arc::new(StringArray::from(vec![ + "already exists", + ]))])?; + return Ok(record); + } + return Err(SqlError::TableAlreadyExists(reference.to_string()).into()); + }; + let blob = descriptor.encode_to_vec(); + let (created, _id, _ts) = + client.create_descriptor("table", schema.id, reference.table.to_string(), blob, if_not_exists).await?; + let record = match created { + true => RecordBatch::try_new(CREATE_TABLE_SCHEMA.clone(), vec![Arc::new(StringArray::from(vec!["created"]))])?, + false => RecordBatch::try_new(CREATE_TABLE_SCHEMA.clone(), vec![Arc::new(StringArray::from(vec![ + "already exists", + ]))])?, + }; + Ok(record) +} + +async fn create_table( + client: Arc<SqlClient>, + reference: ResolvedTableReference, + descriptor: TableDescriptor, + if_not_exists: bool, + sender: mpsc::Sender<Result<RecordBatch, DataFusionError>>, +) -> Result<(), DataFusionError> { + let result = create_table_internally(client, reference, descriptor, if_not_exists).await; + sender.send(result).await.ignore(); + Ok(()) +} diff --git a/src/sql/plan/drop_table.rs b/src/sql/plan/drop_table.rs new file mode 100644 index 0000000..f77cc65 --- /dev/null +++ b/src/sql/plan/drop_table.rs @@ -0,0 +1,129 @@ +// Copyright 2023 The SeamDB Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
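+ +//! DROP TABLE support: `DropTableExec` resolves the table through +//! `TableDescriptorFetcher`, and `SqlClient::drop_table` deletes the +//! descriptor row and then scans the table's key range to delete its data +//! (see the FIXMEs in client.rs about batched deletes).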
+ +use std::any::Any; +use std::fmt::{self, Formatter}; +use std::sync::Arc; + +use datafusion::catalog::TableProvider; +use datafusion::catalog_common::ResolvedTableReference; +use datafusion::common::arrow::datatypes::Schema; +use datafusion::common::arrow::record_batch::RecordBatch; +use datafusion::common::{DataFusionError, Result as DFResult}; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::stream::RecordBatchReceiverStreamBuilder; +use datafusion::physical_plan::{ + DisplayAs, + DisplayFormatType, + ExecutionMode, + ExecutionPlan, + Partitioning, + PlanProperties, +}; +use datafusion::sql::TableReference; +use ignore_result::Ignore; +use tokio::sync::mpsc; + +use super::table::SqlTable; +use crate::sql::client::SqlClient; +use crate::sql::descriptor::TableDescriptorFetcher; +use crate::sql::SqlError; + +#[derive(Debug)] +pub struct DropTableExec { + name: TableReference, + if_exists: bool, + schema: Arc<Schema>, + properties: PlanProperties, +} + +impl DropTableExec { + pub fn new(name: TableReference, if_exists: bool) -> Self { + let schema = Arc::new(Schema::empty()); + let properties = PlanProperties::new( + EquivalenceProperties::new(schema.clone()), + Partitioning::UnknownPartitioning(1), + ExecutionMode::Bounded, + ); + Self { name, if_exists, schema, properties } + } +} + +impl DisplayAs for DropTableExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { + write!(f, "DropTable({})", self.name) + } +} + +impl ExecutionPlan for DropTableExec { + fn name(&self) -> &str { + "DropTableExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn execute(&self, _partition: usize, context: Arc<TaskContext>) -> DFResult<SendableRecordBatchStream> { + let client = context + .session_config() + .get_extension::<SqlClient>() + .ok_or_else(|| DataFusionError::Execution("no sql client".to_string()))?; + let mut builder = RecordBatchReceiverStreamBuilder::new(self.schema.clone(), 1); + let sender = builder.tx(); + let name = self.name.clone().resolve("", ""); + let if_exists = self.if_exists; + builder.spawn(drop_table(client, name, if_exists, sender)); + Ok(builder.build()) + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { + vec![] + } + + fn with_new_children(self: Arc<Self>, _children: Vec<Arc<dyn ExecutionPlan>>) -> DFResult<Arc<dyn ExecutionPlan>> { + Ok(self) + } +} + +async fn drop_table_internally( + client: Arc<SqlClient>, + name: ResolvedTableReference, + if_exists: bool, +) -> Result<bool, SqlError> { + let mut fetcher = TableDescriptorFetcher::new(&client); + let Some(table_provider) = fetcher.get_table(&name).await? else { + if if_exists { + return Ok(false); + } + return Err(SqlError::unexpected(format!("table {name} does not exist"))); + }; + let table = table_provider.as_any().downcast_ref::<SqlTable>().unwrap(); + client.drop_table(&table.descriptor).await +} + +async fn drop_table( + client: Arc<SqlClient>, + name: ResolvedTableReference, + if_exists: bool, + sender: mpsc::Sender<DFResult<RecordBatch>>, +) -> DFResult<()> { + if let Err(err) = drop_table_internally(client, name, if_exists).await { + sender.send(Err(err.into())).await.ignore(); + } + Ok(()) +} diff --git a/src/sql/plan/insert.rs b/src/sql/plan/insert.rs new file mode 100644 index 0000000..2e83f73 --- /dev/null +++ b/src/sql/plan/insert.rs @@ -0,0 +1,252 @@ +// Copyright 2023 The SeamDB Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::any::Any; +use std::fmt; +use std::sync::Arc; + +use datafusion::catalog::TableProvider; +use datafusion::common::arrow::array::{ + Array, + BinaryArray, + BooleanArray, + Float32Array, + Float64Array, + Int16Array, + Int32Array, + Int64Array, + StringArray, + UInt64Array, +}; +use datafusion::common::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::common::arrow::record_batch::RecordBatch; +use datafusion::common::{plan_err, DataFusionError, Result as DFResult, SchemaExt}; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::stream::RecordBatchReceiverStreamBuilder; +use datafusion::physical_plan::{ + DisplayAs, + DisplayFormatType, + ExecutionMode, + ExecutionPlan, + Partitioning, + PlanProperties, +}; +use futures::prelude::stream::StreamExt; +use ignore_result::Ignore; +use lazy_static::lazy_static; +use tokio::sync::mpsc; +use tracing::{instrument, trace}; + +use super::table::SqlTable; +use crate::protos::{ColumnTypeKind, ColumnValue}; +use crate::sql::client::{Column, Row, SqlClient}; + +lazy_static! { + pub static ref INSERT_COUNT_SCHEMA: Arc<Schema> = + Arc::new(Schema::new(vec![Field::new("count", DataType::UInt64, false)])); +} + +#[derive(Debug)] +pub struct InsertExec { + table: SqlTable, + input: Arc<dyn ExecutionPlan>, + properties: PlanProperties, +} + +impl InsertExec { + pub fn new(table: SqlTable, input: Arc<dyn ExecutionPlan>) -> Self { + let properties = PlanProperties::new( + EquivalenceProperties::new(table.schema().clone()), + Partitioning::UnknownPartitioning(1), + ExecutionMode::Bounded, + ); + Self { table, input, properties } + } +} + +impl DisplayAs for InsertExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "InsertExec(table: {}, input: ", self.table.name())?; + self.input.fmt_as(t, f)?; + write!(f, ")") + } +} + +impl ExecutionPlan for InsertExec { + fn name(&self) -> &str { + "InsertExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn execute(&self, partition: usize, context: Arc<TaskContext>) -> DFResult<SendableRecordBatchStream> { + if !self.table.schema().logically_equivalent_names_and_types(&self.input.schema()) { + return plan_err!("insert expects schema {:?}, but got {:?}", self.table.schema(), self.input.schema()); + } + let client = context + .session_config() + .get_extension::<SqlClient>() + .ok_or_else(|| DataFusionError::Execution("no sql client".to_string()))?; + let stream = self.input.execute(partition, context)?; + let mut builder = RecordBatchReceiverStreamBuilder::new(INSERT_COUNT_SCHEMA.clone(), 1); + let sender = builder.tx(); + builder.spawn(insert_into_table(client, self.table.clone(), stream, sender)); + Ok(builder.build()) + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { + vec![] + } + + fn with_new_children(self: Arc<Self>, _children: Vec<Arc<dyn ExecutionPlan>>) -> DFResult<Arc<dyn ExecutionPlan>> { + Ok(self) + } +} + +async fn insert_into_table_internally( + client: Arc<SqlClient>, + table: SqlTable, + mut stream: SendableRecordBatchStream, +) -> DFResult<RecordBatch> { + let schema = stream.schema(); + let mut rows = vec![]; + while 
let Some(record) = stream.next().await { + let record = record?; + rows.extend(std::iter::repeat_n(Row::default(), record.num_rows())); + for (i, field) in schema.fields.iter().enumerate() { + let column_values = record.column(i); + if column_values.null_count() == record.num_rows() { + continue; + } + let column_desc = table.descriptor.find_column(field.name().as_str()).unwrap(); + match column_desc.type_kind { + ColumnTypeKind::Boolean => { + let array = column_values.as_any().downcast_ref::().unwrap(); + for (i, value) in array.values().iter().enumerate() { + let column = if array.is_null(i) { + Column::new_null(column_desc.id) + } else { + Column::with_value(column_desc.id, ColumnValue::Boolean(value)) + }; + rows[i].add_column(column); + } + }, + ColumnTypeKind::Int16 => { + let array = column_values.as_any().downcast_ref::().unwrap(); + for (i, value) in array.values().as_ref().iter().copied().enumerate() { + let column = if array.is_null(i) { + Column::new_null(column_desc.id) + } else { + Column::with_value(column_desc.id, ColumnValue::Int16(value.into())) + }; + rows[i].add_column(column); + } + }, + ColumnTypeKind::Int32 => { + let array = column_values.as_any().downcast_ref::().unwrap(); + for (i, value) in array.values().as_ref().iter().copied().enumerate() { + let column = if array.is_null(i) { + Column::new_null(column_desc.id) + } else { + Column::with_value(column_desc.id, ColumnValue::Int32(value)) + }; + rows[i].add_column(column); + } + }, + ColumnTypeKind::Int64 => { + let array = column_values.as_any().downcast_ref::().unwrap(); + for (i, value) in array.values().as_ref().iter().copied().enumerate() { + let column = if array.is_null(i) { + Column::new_null(column_desc.id) + } else { + Column::with_value(column_desc.id, ColumnValue::Int64(value)) + }; + rows[i].add_column(column); + } + }, + ColumnTypeKind::Float32 => { + let array = column_values.as_any().downcast_ref::().unwrap(); + for (i, value) in array.values().as_ref().iter().copied().enumerate() { + let column = if array.is_null(i) { + Column::new_null(column_desc.id) + } else { + Column::with_value(column_desc.id, ColumnValue::Float32(value.into())) + }; + rows[i].add_column(column); + } + }, + ColumnTypeKind::Float64 => { + let array = column_values.as_any().downcast_ref::().unwrap(); + for (i, value) in array.values().as_ref().iter().copied().enumerate() { + let column = if array.is_null(i) { + Column::new_null(column_desc.id) + } else { + Column::with_value(column_desc.id, ColumnValue::Float64(value.into())) + }; + rows[i].add_column(column); + } + }, + ColumnTypeKind::String => { + let array = column_values.as_any().downcast_ref::().unwrap(); + for (i, row) in rows.iter_mut().enumerate() { + let column = if array.is_null(i) { + Column::new_null(column_desc.id) + } else { + Column::with_value(column_desc.id, ColumnValue::String(array.value(i).to_owned())) + }; + row.add_column(column); + } + }, + ColumnTypeKind::Bytes => { + let array = column_values.as_any().downcast_ref::().unwrap(); + for (i, row) in rows.iter_mut().enumerate() { + let column = if array.is_null(i) { + Column::new_null(column_desc.id) + } else { + Column::with_value(column_desc.id, ColumnValue::Bytes(array.value(i).to_owned())) + }; + row.add_column(column); + } + }, + } + } + } + for row in rows.iter_mut() { + trace!("prefill row: {row:?}"); + client.prefill_row(&table.descriptor, row).await?; + } + client.insert_rows_once(&table.descriptor, rows.iter()).await?; + let record = + RecordBatch::try_new(INSERT_COUNT_SCHEMA.clone(), 
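+            // A single-row batch reporting how many rows were inserted.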
vec![Arc::new(UInt64Array::from(vec![rows.len() as u64]))])?;
+    Ok(record)
+}
+
+#[instrument(skip_all, fields(table.name = table.descriptor.name))]
+async fn insert_into_table(
+    client: Arc<SqlClient>,
+    table: SqlTable,
+    stream: SendableRecordBatchStream,
+    sender: mpsc::Sender<DFResult<RecordBatch>>,
+) -> DFResult<()> {
+    let result = insert_into_table_internally(client, table, stream).await;
+    sender.send(result).await.ignore();
+    Ok(())
+}
diff --git a/src/sql/plan/mod.rs b/src/sql/plan/mod.rs
new file mode 100644
index 0000000..d9051fb
--- /dev/null
+++ b/src/sql/plan/mod.rs
@@ -0,0 +1,26 @@
+// Copyright 2023 The SeamDB Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+mod catalog;
+mod create_table;
+mod drop_table;
+mod insert;
+mod planner;
+mod table;
+
+pub use self::catalog::*;
+pub use self::create_table::*;
+pub use self::drop_table::*;
+pub use self::planner::*;
+pub use self::table::*;
diff --git a/src/sql/plan/planner.rs b/src/sql/plan/planner.rs
new file mode 100644
index 0000000..34fbe14
--- /dev/null
+++ b/src/sql/plan/planner.rs
@@ -0,0 +1,51 @@
+// Copyright 2023 The SeamDB Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use datafusion::common::DataFusionError;
+use datafusion::execution::session_state::SessionState;
+use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNode};
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
+
+use super::*;
+
+struct PostgresPhysicalPlanner {}
+
+#[async_trait::async_trait]
+impl ExtensionPlanner for PostgresPhysicalPlanner {
+    async fn plan_extension(
+        &self,
+        _planner: &dyn PhysicalPlanner,
+        node: &dyn UserDefinedLogicalNode,
+        _logical_inputs: &[&LogicalPlan],
+        _physical_inputs: &[Arc<dyn ExecutionPlan>],
+        _session_state: &SessionState,
+    ) -> Result<Option<Arc<dyn ExecutionPlan>>, DataFusionError> {
+        if let Some(plan) = node.as_any().downcast_ref::<CreateTablePlan>() {
+            let physical_plan = CreateTableExec::new(
+                plan.table_ref.clone().resolve("", ""),
+                plan.table_descriptor.clone(),
+                plan.if_not_exists,
+            );
+            return Ok(Some(Arc::new(physical_plan)));
+        }
+        Ok(None)
+    }
+}
+
+pub fn get_extension_planners() -> Vec<Arc<dyn ExtensionPlanner + Send + Sync>> {
+    vec![Arc::new(PostgresPhysicalPlanner {})]
+}
diff --git a/src/sql/plan/table.rs b/src/sql/plan/table.rs
new file mode 100644
index 0000000..35a5ea5
--- /dev/null
+++ b/src/sql/plan/table.rs
@@ -0,0 +1,432 @@
+// Copyright 2023 The SeamDB Authors.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::any::Any; +use std::fmt; +use std::sync::Arc; + +use bytes::BufMut; +use datafusion::catalog::{Session, TableProvider}; +use datafusion::common::arrow::array::builder::{ + BinaryBuilder, + BooleanBuilder, + Float32Builder, + Float64Builder, + Int16Builder, + Int32Builder, + Int64Builder, + StringBuilder, + StructBuilder, +}; +use datafusion::common::arrow::datatypes::{self, Field, Fields, Schema, SchemaRef}; +use datafusion::common::arrow::record_batch::RecordBatch; +use datafusion::common::{not_impl_err, plan_err, Constraint, Constraints, DataFusionError, Result as DFResult}; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::logical_expr::logical_plan::dml::InsertOp; +use datafusion::logical_expr::{Expr, TableSource, TableType}; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::execution_plan::project_schema; +use datafusion::physical_plan::stream::RecordBatchReceiverStreamBuilder; +use datafusion::physical_plan::{ + DisplayAs, + DisplayFormatType, + ExecutionMode, + ExecutionPlan, + Partitioning, + PlanProperties, +}; +use tokio::sync::mpsc; + +use super::insert::InsertExec; +use crate::kv::KvClient; +use crate::protos::{ColumnTypeKind, ColumnValue, IndexDescriptor, TableDescriptor, TimestampedKeyValue}; +use crate::sql::client::SqlClient; +use crate::sql::error::SqlError; + +impl TryFrom<&datatypes::DataType> for ColumnTypeKind { + type Error = DataFusionError; + + fn try_from(value: &datatypes::DataType) -> DFResult { + match value { + datatypes::DataType::Int64 => Ok(ColumnTypeKind::Int64), + datatypes::DataType::Binary => Ok(ColumnTypeKind::Bytes), + datatypes::DataType::Utf8 => Ok(ColumnTypeKind::String), + _ => plan_err!("unsupported data type {value}"), + } + } +} + +impl From for datatypes::DataType { + fn from(kind: ColumnTypeKind) -> datatypes::DataType { + match kind { + ColumnTypeKind::Boolean => datatypes::DataType::Boolean, + ColumnTypeKind::Int16 => datatypes::DataType::Int16, + ColumnTypeKind::Int32 => datatypes::DataType::Int32, + ColumnTypeKind::Int64 => datatypes::DataType::Int64, + ColumnTypeKind::Float32 => datatypes::DataType::Float32, + ColumnTypeKind::Float64 => datatypes::DataType::Float64, + ColumnTypeKind::Bytes => datatypes::DataType::Binary, + ColumnTypeKind::String => datatypes::DataType::Utf8, + } + } +} + +#[derive(Clone, Debug)] +pub struct SqlTable { + pub descriptor: TableDescriptor, + schema: Arc, + constraints: Constraints, +} + +fn build_table_schema_and_constraints(descriptor: &TableDescriptor) -> (Schema, Constraints) { + let fields: Fields = descriptor + .columns + .iter() + .map(|column| Field::new(column.name.clone(), column.type_kind.into(), column.nullable)) + .collect(); + let schema = Schema::new(fields); + let mut constraints = Vec::with_capacity(descriptor.indices.len()); + for (i, index) in descriptor.indices.iter().enumerate() { + if !index.unique { + continue; + } + let indices = index + 
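+            // DataFusion constraints are positional, so map the index's column
+            // ids to their ordinal positions in the Arrow schema.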
.column_ids + .iter() + .copied() + .map(|id| descriptor.columns.iter().position(|column| column.id == id).unwrap()) + .collect(); + let constraint = if i == 0 { Constraint::PrimaryKey(indices) } else { Constraint::Unique(indices) }; + constraints.push(constraint); + } + (schema, Constraints::new_unverified(constraints)) +} + +impl SqlTable { + pub fn new(descriptor: TableDescriptor) -> Self { + let (schema, constraints) = build_table_schema_and_constraints(&descriptor); + Self { descriptor, schema: Arc::new(schema), constraints } + } + + pub fn name(&self) -> &str { + &self.descriptor.name + } +} + +impl TableSource for SqlTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> Arc { + self.schema.clone() + } + + fn constraints(&self) -> Option<&Constraints> { + Some(&self.constraints) + } +} + +#[async_trait::async_trait] +impl TableProvider for SqlTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> Arc { + self.schema.clone() + } + + fn constraints(&self) -> Option<&Constraints> { + Some(&self.constraints) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result, DataFusionError> { + let schema = project_schema(&self.schema, projection)?; + Ok(Arc::new(SqlTableScanExec::new(self.clone(), schema))) + } + + async fn insert_into( + &self, + _state: &dyn Session, + input: Arc, + insert_op: InsertOp, + ) -> Result, DataFusionError> { + if insert_op != InsertOp::Append { + return not_impl_err!("INSERT INTO .. ON CONFLICT .. DO UPDATE SET .."); + } + let insert = InsertExec::new(self.clone(), input); + Ok(Arc::new(insert)) + } +} + +#[derive(Debug)] +struct SqlTableScanExec { + table: SqlTable, + schema: SchemaRef, + properties: PlanProperties, +} + +impl SqlTableScanExec { + pub fn new(table: SqlTable, schema: SchemaRef) -> Self { + let properties = PlanProperties::new( + EquivalenceProperties::new(schema.clone()), + Partitioning::UnknownPartitioning(1), + ExecutionMode::Bounded, + ); + Self { table, schema, properties } + } +} + +impl DisplayAs for SqlTableScanExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "SqlTableScanExec(table: {})", self.table.name()) + } +} + +impl ExecutionPlan for SqlTableScanExec { + fn name(&self) -> &str { + "SqlTableScanExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn execute( + &self, + _partition: usize, + context: Arc, + ) -> Result { + let client = context + .session_config() + .get_extension::() + .ok_or_else(|| DataFusionError::Execution("no sql client".to_string()))?; + let schema = self.schema.clone(); + let mut builder = RecordBatchReceiverStreamBuilder::new(schema.clone(), 128); + let table = self.table.clone(); + let sender = builder.tx(); + builder.spawn(scan_table(client, table, schema, sender)); + Ok(builder.build()) + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result, DataFusionError> { + Ok(self) + } +} + +struct ClusteredRowBuilder<'a> { + table: &'a TableDescriptor, + index: &'a IndexDescriptor, + columns: Vec>, + builder: StructBuilder, +} + +impl<'a> ClusteredRowBuilder<'a> { + pub fn new(fields: Fields, table: &'a TableDescriptor, index: &'a IndexDescriptor) -> Self { + let mut columns = Vec::with_capacity(index.column_ids.len() + 
index.storing_column_ids.len());
+        for column_id in index.column_ids.iter().chain(index.storing_column_ids.iter()).copied() {
+            let column = table.column(column_id).unwrap();
+            let Some((i, _field)) = fields.find(&column.name) else {
+                columns.push(Default::default());
+                continue;
+            };
+            columns.push(Some((i, column.type_kind)));
+        }
+        Self { table, index, columns, builder: StructBuilder::from_fields(fields, 128) }
+    }
+
+    pub fn finish(&mut self) -> RecordBatch {
+        let array = self.builder.finish();
+        RecordBatch::from(array)
+    }
+
+    pub fn add_row(&mut self, row: TimestampedKeyValue) {
+        let keys = self.table.decode_index_key(self.index, &row.key);
+        let value_bytes = row.value.read_bytes(&row.key, "sql key value").unwrap();
+        let values = self.table.decode_storing_columns(self.index, value_bytes);
+        for (i, value) in keys.iter().map(Some).chain(values.iter().map(Option::as_ref)).enumerate() {
+            let Some((j, type_kind)) = self.columns[i] else {
+                continue;
+            };
+            self.add_field(j, type_kind, value);
+        }
+        self.builder.append(true);
+    }
+
+    fn add_boolean_field(builder: &mut BooleanBuilder, value: Option<&ColumnValue>) {
+        let value = value.map(|v| match v {
+            ColumnValue::Boolean(v) => *v,
+            _ => panic!("not boolean: {:?}", v),
+        });
+        builder.append_option(value);
+    }
+
+    fn add_int16_field(builder: &mut Int16Builder, value: Option<&ColumnValue>) {
+        let value = value.map(|v| match v {
+            ColumnValue::Int16(int) => *int as i16,
+            _ => panic!("not int16: {:?}", v),
+        });
+        builder.append_option(value);
+    }
+
+    fn add_int32_field(builder: &mut Int32Builder, value: Option<&ColumnValue>) {
+        let value = value.map(|v| match v {
+            ColumnValue::Int32(int) => *int,
+            _ => panic!("not int32: {:?}", v),
+        });
+        builder.append_option(value);
+    }
+
+    fn add_int64_field(builder: &mut Int64Builder, value: Option<&ColumnValue>) {
+        let value = value.map(|v| match v {
+            ColumnValue::Int64(int) => *int,
+            _ => panic!("not int64: {:?}", v),
+        });
+        builder.append_option(value);
+    }
+
+    fn add_float32_field(builder: &mut Float32Builder, value: Option<&ColumnValue>) {
+        let value = value.map(|v| match v {
+            ColumnValue::Float32(v) => v.value,
+            _ => panic!("not float32: {:?}", v),
+        });
+        builder.append_option(value);
+    }
+
+    fn add_float64_field(builder: &mut Float64Builder, value: Option<&ColumnValue>) {
+        let value = value.map(|v| match v {
+            ColumnValue::Float64(v) => v.value,
+            _ => panic!("not float64: {:?}", v),
+        });
+        builder.append_option(value);
+    }
+
+    fn add_binary_field(builder: &mut BinaryBuilder, value: Option<&ColumnValue>) {
+        let bytes = value.map(|v| match v {
+            ColumnValue::Bytes(bytes) => bytes.as_slice(),
+            _ => panic!("not bytes: {:?}", v),
+        });
+        builder.append_option(bytes);
+    }
+
+    fn add_string_field(builder: &mut StringBuilder, value: Option<&ColumnValue>) {
+        let string = value.map(|v| match v {
+            ColumnValue::String(string) => string.as_str(),
+            _ => panic!("not string: {:?}", v),
+        });
+        builder.append_option(string);
+    }
+
+    fn add_field(&mut self, i: usize, type_kind: ColumnTypeKind, value: Option<&ColumnValue>) {
+        match type_kind {
+            ColumnTypeKind::Boolean => {
+                let builder = self.builder.field_builder::<BooleanBuilder>(i).unwrap();
+                Self::add_boolean_field(builder, value);
+            },
+            ColumnTypeKind::Int16 => {
+                let builder = self.builder.field_builder::<Int16Builder>(i).unwrap();
+                Self::add_int16_field(builder, value);
+            },
+            ColumnTypeKind::Int32 => {
+                let builder = self.builder.field_builder::<Int32Builder>(i).unwrap();
+                Self::add_int32_field(builder, value);
+            },
+            ColumnTypeKind::Int64 => {
+                let builder = self.builder.field_builder::<Int64Builder>(i).unwrap();
+                Self::add_int64_field(builder, value);
+            },
+            ColumnTypeKind::Float32 => {
+                let builder = self.builder.field_builder::<Float32Builder>(i).unwrap();
+                Self::add_float32_field(builder, value);
+            },
+            ColumnTypeKind::Float64 => {
+                let builder = self.builder.field_builder::<Float64Builder>(i).unwrap();
+                Self::add_float64_field(builder, value);
+            },
+            ColumnTypeKind::Bytes => {
+                let builder = self.builder.field_builder::<BinaryBuilder>(i).unwrap();
+                Self::add_binary_field(builder, value);
+            },
+            ColumnTypeKind::String => {
+                let builder = self.builder.field_builder::<StringBuilder>(i).unwrap();
+                Self::add_string_field(builder, value);
+            },
+        }
+    }
+}
+
+async fn scan_table_internally(
+    client: Arc<SqlClient>,
+    table: SqlTable,
+    schema: SchemaRef,
+    sender: mpsc::Sender<DFResult<RecordBatch>>,
+) -> Result<(), SqlError> {
+    let primary_index = table.descriptor.primary_index();
+    let mut start = table.descriptor.index_prefix(primary_index);
+    // Upper bound for this index's key space: the prefix followed by u32::MAX.
+    let index_end = {
+        let mut bytes = start.clone();
+        bytes.put_u32(u32::MAX);
+        bytes
+    };
+    let mut builder = ClusteredRowBuilder::new(schema.fields.clone(), &table.descriptor, primary_index);
+    loop {
+        let (resume_key, rows) = client.scan(&start, &index_end, 0).await?;
+        if !rows.is_empty() {
+            for row in rows {
+                builder.add_row(row);
+            }
+            let record_batch = builder.finish();
+            if sender.send(Ok(record_batch)).await.is_err() {
+                break;
+            }
+        }
+        if resume_key.is_empty() || resume_key >= index_end {
+            break;
+        }
+        start = resume_key;
+    }
+    Ok(())
+}
+
+async fn scan_table(
+    client: Arc<SqlClient>,
+    table: SqlTable,
+    schema: SchemaRef,
+    sender: mpsc::Sender<DFResult<RecordBatch>>,
+) -> Result<(), DataFusionError> {
+    scan_table_internally(client, table, schema, sender).await?;
+    Ok(())
+}
diff --git a/src/sql/postgres.rs b/src/sql/postgres.rs
new file mode 100644
index 0000000..c00e661
--- /dev/null
+++ b/src/sql/postgres.rs
@@ -0,0 +1,596 @@
+// Copyright 2023 The SeamDB Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
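+
+//! PostgreSQL frontend: SQL text is parsed with sqlparser's `PostgreSqlDialect`,
+//! CREATE TABLE is lowered by hand into a `CreateTablePlan` extension node, and
+//! all other statements go through DataFusion's `SqlToRel` planner, with pgwire
+//! serving the wire protocol.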
+ +use std::collections::HashMap; +use std::fmt::Debug; +use std::ops::ControlFlow; +use std::sync::Arc; + +use datafusion::catalog::TableProvider; +use datafusion::common::arrow::array::{ + Array, + BinaryArray, + BooleanArray, + Float32Array, + Float64Array, + Int16Array, + Int32Array, + Int64Array, + RecordBatch, + StringArray, + UInt64Array, +}; +use datafusion::common::arrow::datatypes::{self, Schema}; +use datafusion::common::{plan_datafusion_err, DataFusionError, Result as DFResult}; +use datafusion::config::ConfigOptions; +use datafusion::datasource::default_table_source::DefaultTableSource; +use datafusion::execution::session_state::SessionState; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::logical_expr::sqlparser::ast::{CharacterLength, ColumnDef, CreateTable, Statement, TableConstraint}; +use datafusion::logical_expr::var_provider::{is_system_variables, VarType}; +use datafusion::logical_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; +use datafusion::sql::planner::{object_name_to_table_reference, ContextProvider, SqlToRel}; +use datafusion::sql::sqlparser::ast::{ColumnOption, DataType}; +use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; +use datafusion::sql::sqlparser::parser::Parser; +use datafusion::sql::TableReference; +use futures::prelude::stream::StreamExt; +use futures::{Sink, Stream}; +use pgwire::api::auth::noop::NoopStartupHandler; +use pgwire::api::copy::NoopCopyHandler; +use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; +use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response}; +use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type, METADATA_DATABASE}; +use pgwire::error::{PgWireError, PgWireResult}; +use pgwire::messages::data::DataRow; +use pgwire::messages::PgWireBackendMessage; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tracing::debug; + +use super::error::SqlError; +use super::plan::*; +use super::shared::*; +use super::traits::*; +use crate::protos::{CharacterTypeDeclaration, ColumnTypeDeclaration, ColumnTypeKind}; +use crate::sql::PostgreSqlExecutor; +use crate::tablet::TabletClient; + +pub struct PostgresPlanner<'a> { + context: &'a (dyn PlannerContext + Sync), +} + +impl<'a> PostgresPlanner<'a> { + pub fn new(context: &'a (dyn PlannerContext + Sync)) -> Self { + Self { context } + } + + pub async fn plan( + &self, + sql: &str, + ) -> Result<(LogicalPlan, HashMap>), SqlError> { + let dialect = PostgreSqlDialect {}; + let mut statements = Parser::parse_sql(&dialect, sql)?; + let statement = match statements.len() { + 0 => return Err(SqlError::NoStatement), + 1 => statements.remove(0), + _ => return Err(SqlError::MultipleStatements), + }; + self.plan_statement(statement).await + } + + async fn plan_statement( + &self, + statement: Statement, + ) -> Result<(LogicalPlan, HashMap>), SqlError> { + let resolved_table_references = self.context.collect_table_references(&statement)?; + let tables = self.context.fetch_table_references(resolved_table_references).await?; + let provider = PostgresContextProvider::new(self.context.state(), &tables); + let planner = SqlToRel::new(&provider); + let plan = match statement { + Statement::CreateTable(CreateTable { + or_replace, + temporary, + external, + global, + if_not_exists, + name, + columns, + constraints, + .. 
+ }) => { + if or_replace { + return Err(SqlError::unsupported("CREATE OR REPLACE TABLE ..")); + } + if temporary || global.is_some() { + return Err(SqlError::unsupported("CREATE [ GLOBAL | LOCAL ] { TEMPORARY | TEMP } TABLE ..")); + } + if external { + return Err(SqlError::unsupported("CREATE EXTERNAL TABLE ..")); + } + let table_ref = self.context.resolve_table_reference(object_name_to_table_reference(name, true)?); + if columns.is_empty() { + return Err(SqlError::invalid(format!("no columns in creating table {table_ref}"))); + } + let mut table_descriptor_builder = TableDescriptorBuilder::new(); + for ColumnDef { name, data_type, collation, options } in columns { + if collation.is_some() { + return Err(SqlError::unsupported("CREATE TABLE table_name (column_name .. COLLATE ..")); + } + let column_name = name.value; + let (serial, type_kind, type_declaration) = match data_type { + DataType::Boolean | DataType::Bool => (false, ColumnTypeKind::Boolean, None), + DataType::Text => (false, ColumnTypeKind::String, None), + DataType::Int(_) | DataType::Int4(_) | DataType::Integer(_) => { + (false, ColumnTypeKind::Int32, None) + }, + DataType::SmallInt(_) | DataType::Int2(_) => (false, ColumnTypeKind::Int16, None), + DataType::BigInt(_) | DataType::Int8(_) => (false, ColumnTypeKind::Int64, None), + DataType::Varchar(n) | DataType::CharacterVarying(n) => match n { + None => (false, ColumnTypeKind::String, None), + Some(CharacterLength::IntegerLength { unit: Some(_), .. }) => { + return Err(SqlError::invalid("character length units")) + }, + Some(CharacterLength::IntegerLength { length, .. }) => ( + false, + ColumnTypeKind::String, + Some(ColumnTypeDeclaration::Character(CharacterTypeDeclaration { + max_length: length as u32, + })), + ), + Some(CharacterLength::Max) => return Err(SqlError::invalid("varchar(MAX)")), + }, + DataType::Real | DataType::Float4 => (false, ColumnTypeKind::Float32, None), + DataType::DoublePrecision | DataType::Float8 => (false, ColumnTypeKind::Float64, None), + DataType::Custom(name, _modifiers) => { + let type_name = name.to_string(); + match type_name.as_str() { + "smallserial" | "serial2" => (true, ColumnTypeKind::Int16, None), + "serial" | "serial4" => (true, ColumnTypeKind::Int32, None), + "bigserial" | "serial8" => (true, ColumnTypeKind::Int64, None), + _ => return Err(SqlError::unsupported(format!("data type {type_name}"))), + } + }, + DataType::Bytea => (false, ColumnTypeKind::Bytes, None), + _ => return Err(SqlError::unsupported(format!("data type {data_type}"))), + }; + + let mut column_builder = + table_descriptor_builder.add_column(column_name, type_kind, type_declaration)?; + for option in options { + match option.option { + ColumnOption::Null => column_builder.set_nullable(true), + ColumnOption::NotNull => column_builder.set_nullable(false), + ColumnOption::Default(_) => return Err(SqlError::unimplemented("DEFAULT expr")), + ColumnOption::Unique { is_primary, characteristics } => { + if characteristics.is_some() { + return Err(SqlError::unsupported("CREATE TABLE .. [ DEFERRABLE | NOT DEFERRABLE ] [ INITIALLY DEFERRED | INITIALLY IMMEDIATE ]")); + } + column_builder + .add_unique_constraint(is_primary, option.name.map(|ident| ident.value))?; + }, + ColumnOption::Materialized(_) => { + return Err(SqlError::invalid("CREATE TABLE .. (column_name .. MATERIALIZE ..)")) + }, + ColumnOption::Ephemeral(_) => { + return Err(SqlError::invalid("CREATE TABLE .. (column_name .. 
EPHEMERAL ..)")) + }, + ColumnOption::Alias(_) => { + return Err(SqlError::invalid("CREATE TABLE .. (column_name .. ALIAS ..) ..")) + }, + ColumnOption::ForeignKey { .. } => { + return Err(SqlError::unsupported("CREATE TABLE .. FOREIGN KEY ..")) + }, + ColumnOption::Check(_) => { + return Err(SqlError::unsupported("CREATE TABLE .. (column_name .. CHECK ..) ..")) + }, + ColumnOption::OnUpdate(_) => { + return Err(SqlError::unsupported("CREATE TABLE .. (column_name .. ON UPDATE ..)")) + }, + ColumnOption::DialectSpecific(_) => {}, + ColumnOption::CharacterSet(_) => {}, + ColumnOption::Comment(_) => {}, + ColumnOption::Generated { .. } => { + return Err(SqlError::unsupported("CREATE TABLE .. (column_name .. GENERATED ..)")) + }, + ColumnOption::Options(_) => { + return Err(SqlError::unsupported("CREATE TABLE .. (column_name .. OPTIONS(..) ..)")) + }, + } + } + column_builder.set_serial(serial); + } + for constraint in constraints { + match constraint { + TableConstraint::Unique { name, index_name, columns, .. } => { + let index_name = index_name.or(name).map(|ident| ident.value); + let column_names = columns.into_iter().map(|ident| ident.value).collect(); + table_descriptor_builder.add_unique_index(index_name, column_names)?; + }, + TableConstraint::PrimaryKey { name, index_name, columns, .. } => { + let index_name = index_name.or(name).map(|ident| ident.value); + let column_names = columns.into_iter().map(|ident| ident.value).collect(); + table_descriptor_builder.add_primary_index(index_name, column_names)?; + }, + TableConstraint::ForeignKey { .. } => return Err(SqlError::unsupported("FOREIGN KEY")), + TableConstraint::Check { .. } => return Err(SqlError::unsupported("CHECK")), + TableConstraint::Index { name, columns, .. } => { + let index_name = name.map(|ident| ident.value); + table_descriptor_builder + .add_index(index_name, columns.into_iter().map(|ident| ident.value).collect())?; + }, + TableConstraint::FulltextOrSpatial { .. 
} => { + return Err(SqlError::unsupported( + "{FULLTEXT | SPATIAL} [INDEX | KEY] [index_name] (key_part,...)", + )) + }, + } + } + let table_descriptor = table_descriptor_builder.build()?; + CreateTablePlan::new(table_ref.into(), table_descriptor, if_not_exists).into() + }, + _ => planner.sql_statement_to_plan(statement)?, + }; + Ok((plan, tables)) + } +} + +struct PostgresContextProvider<'a> { + state: &'a SessionState, + tables: &'a HashMap>, +} + +impl<'a> PostgresContextProvider<'a> { + pub fn new(state: &'a SessionState, tables: &'a HashMap>) -> Self { + Self { state, tables } + } +} + +impl ContextProvider for PostgresContextProvider<'_> { + fn get_table_source(&self, name: TableReference) -> Result, DataFusionError> { + let resolved = self.state.resolve_table_reference(name).into(); + let Some(table) = self.tables.get(&resolved) else { + return Err(DataFusionError::Plan(format!("no table source for {}", resolved))); + }; + Ok(Arc::new(DefaultTableSource::new(table.clone()))) + } + + fn get_table_function_source(&self, name: &str, args: Vec) -> DFResult> { + let tbl_func = self + .state + .table_functions() + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; + let provider = tbl_func.create_table_provider(&args)?; + + Ok(Arc::new(DefaultTableSource::new(provider))) + } + + fn get_function_meta(&self, name: &str) -> Option> { + self.state.scalar_functions().get(name).cloned() + } + + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.state.aggregate_functions().get(name).cloned() + } + + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() + } + + fn get_variable_type(&self, variable_names: &[String]) -> Option { + if variable_names.is_empty() { + return None; + } + + let provider_type = if is_system_variables(variable_names) { VarType::System } else { VarType::UserDefined }; + + self.state + .execution_props() + .var_providers + .as_ref() + .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names)) + } + + fn options(&self) -> &ConfigOptions { + self.state.config_options() + } + + fn udf_names(&self) -> Vec { + self.state.scalar_functions().keys().cloned().collect() + } + + fn udaf_names(&self) -> Vec { + self.state.aggregate_functions().keys().cloned().collect() + } + + fn udwf_names(&self) -> Vec { + self.state.window_functions().keys().cloned().collect() + } +} + +pub struct PostgresqlQueryProcessor { + client: TabletClient, +} + +impl From for PgWireError { + fn from(err: SqlError) -> PgWireError { + PgWireError::ApiError(Box::new(err)) + } +} + +impl NoopStartupHandler for PostgresqlQueryProcessor {} + +fn convert_schema(schema: &Schema) -> Vec { + schema + .fields + .iter() + .map(|field| { + FieldInfo::new(field.name().clone(), None, None, convert_data_type(field.data_type()), FieldFormat::Text) + }) + .collect() +} + +async fn try_send(sender: &mpsc::Sender>, result: PgWireResult) -> ControlFlow<()> { + match sender.send(result).await { + Ok(_) => ControlFlow::Continue(()), + Err(_) => ControlFlow::Break(()), + } +} + +async fn send_record_batch( + fields: &Arc>, + sender: &mpsc::Sender>, + result: DFResult, +) -> ControlFlow<()> { + let record = match result { + Err(err) => return try_send(sender, Err(PgWireError::ApiError(Box::new(err)))).await, + Ok(record) => record, + }; + let mut rows = vec![]; + rows.extend(std::iter::repeat_with(|| DataRowEncoder::new(fields.clone())).take(record.num_rows())); + let schema = record.schema(); + for (j, field) 
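+        // Encode column j into every row's encoder, dispatching on the Arrow
+        // data type to pick the concrete array downcast.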
in schema.fields.iter().enumerate() { + let array = record.column(j); + match field.data_type() { + datatypes::DataType::Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + for (i, value) in array.values().iter().enumerate() { + let value = if array.is_null(i) { None } else { Some(value) }; + rows[i].encode_field(&value).unwrap(); + } + }, + datatypes::DataType::Int16 => { + let array = array.as_any().downcast_ref::().unwrap(); + for (i, value) in array.values().as_ref().iter().copied().enumerate() { + let value = if array.is_null(i) { None } else { Some(value) }; + rows[i].encode_field(&value).unwrap(); + } + }, + datatypes::DataType::Int32 => { + let array = array.as_any().downcast_ref::().unwrap(); + for (i, value) in array.values().as_ref().iter().copied().enumerate() { + let value = if array.is_null(i) { None } else { Some(value) }; + rows[i].encode_field(&value).unwrap(); + } + }, + datatypes::DataType::Int64 => { + let array = array.as_any().downcast_ref::().unwrap(); + for (i, value) in array.values().as_ref().iter().copied().enumerate() { + let value = if array.is_null(i) { None } else { Some(value) }; + rows[i].encode_field(&value).unwrap(); + } + }, + datatypes::DataType::UInt64 => { + let array = array.as_any().downcast_ref::().unwrap(); + for (i, value) in array.values().as_ref().iter().copied().enumerate() { + let value = if array.is_null(i) { None } else { Some(value as i64) }; + rows[i].encode_field(&value).unwrap(); + } + }, + datatypes::DataType::Utf8 => { + let array = array.as_any().downcast_ref::().unwrap(); + for (i, row) in rows.iter_mut().enumerate() { + let value = if array.is_null(i) { None } else { Some(array.value(i)) }; + row.encode_field(&value).unwrap(); + } + }, + datatypes::DataType::Binary => { + let array = array.as_any().downcast_ref::().unwrap(); + for (i, row) in rows.iter_mut().enumerate() { + let value = if array.is_null(i) { None } else { Some(array.value(i)) }; + row.encode_field(&value).unwrap(); + } + }, + datatypes::DataType::Float32 => { + let array = array.as_any().downcast_ref::().unwrap(); + for (i, value) in array.values().as_ref().iter().copied().enumerate() { + let value = if array.is_null(i) { None } else { Some(value) }; + rows[i].encode_field(&value).unwrap(); + } + }, + datatypes::DataType::Float64 => { + let array = array.as_any().downcast_ref::().unwrap(); + for (i, value) in array.values().as_ref().iter().copied().enumerate() { + let value = if array.is_null(i) { None } else { Some(value) }; + rows[i].encode_field(&value).unwrap(); + } + }, + data_type => panic!("unsupported data type {}", data_type), + } + } + for row in rows { + try_send(sender, row.finish()).await?; + } + ControlFlow::Continue(()) +} + +fn convert_record_batch_stream( + fields: Arc>, + mut stream: SendableRecordBatchStream, +) -> impl Stream> { + let (sender, receiver) = mpsc::channel(10); + tokio::spawn(async move { + while let Some(item) = stream.next().await { + if send_record_batch(&fields, &sender, item).await.is_break() { + break; + } + } + }); + ReceiverStream::new(receiver) +} + +fn convert_data_type(data_type: &datatypes::DataType) -> Type { + use datatypes::DataType::*; + match data_type { + Boolean => Type::BOOL, + Int16 | UInt16 => Type::INT2, + Int32 | UInt32 => Type::INT4, + Int64 | UInt64 => Type::INT8, + Utf8 => Type::TEXT, + Float32 => Type::FLOAT4, + Float64 => Type::FLOAT8, + _ => panic!("unsupported data type: {data_type}"), + } +} + +#[async_trait::async_trait] +impl SimpleQueryHandler for PostgresqlQueryProcessor { + 
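+    // Execute one simple-protocol query: a statement with an empty result
+    // schema (DDL/DML) drains its stream and answers EmptyQuery, anything
+    // else streams rows back in text format.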
async fn do_query<'a, C>(&self, client: &mut C, query: &'a str) -> PgWireResult>> + where + C: ClientInfo + Sink + Unpin + Send + Sync, + C::Error: Debug, + PgWireError: From<>::Error>, { + debug!("SQL: {query}"); + let database = client.metadata().get(METADATA_DATABASE).cloned().unwrap_or_else(|| "default".to_string()); + let executor = PostgreSqlExecutor::new(Arc::new(self.client.clone()), database); + let mut stream = executor.execute_sql(query).await?; + let schema = stream.schema(); + if schema.fields.is_empty() { + while let Some(_record) = stream.next().await {} + return Ok(vec![Response::EmptyQuery]); + } + let fields = Arc::new(convert_schema(&schema)); + let stream = convert_record_batch_stream(fields.clone(), stream); + let response = QueryResponse::new(fields, stream); + Ok(vec![Response::Query(response)]) + } +} + +#[derive(Clone)] +pub struct PostgresqlHandlerFactory { + processor: Arc, +} + +impl PostgresqlHandlerFactory { + pub fn new(client: TabletClient) -> Self { + Self { processor: Arc::new(PostgresqlQueryProcessor { client }) } + } +} + +impl PgWireHandlerFactory for PostgresqlHandlerFactory { + type CopyHandler = NoopCopyHandler; + type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; + type SimpleQueryHandler = PostgresqlQueryProcessor; + type StartupHandler = PostgresqlQueryProcessor; + + fn simple_query_handler(&self) -> Arc { + self.processor.clone() + } + + fn extended_query_handler(&self) -> Arc { + Arc::new(PlaceholderExtendedQueryHandler) + } + + fn startup_handler(&self) -> Arc { + self.processor.clone() + } + + fn copy_handler(&self) -> Arc { + Arc::new(NoopCopyHandler) + } +} + +#[cfg(test)] +mod tests { + use datafusion::execution::config::SessionConfig; + use datafusion::execution::session_state::{SessionState, SessionStateBuilder}; + use datafusion::logical_expr::LogicalPlan; + + use super::PostgresPlanner; + use crate::protos::{CharacterTypeDeclaration, ColumnTypeDeclaration, ColumnTypeKind}; + use crate::sql::plan::CreateTablePlan; + + fn session_state() -> SessionState { + let mut config = SessionConfig::default(); + config.options_mut().catalog.create_default_catalog_and_schema = false; + config.options_mut().catalog.default_catalog = "database1".to_string(); + config.options_mut().catalog.information_schema = true; + SessionStateBuilder::new().with_default_features().with_config(config).build() + } + + #[asyncs::test] + async fn create_table() { + let state = session_state(); + let planner = PostgresPlanner::new(&state); + let (plan, _tables) = planner + .plan( + r#"CREATE TABLE IF NOT EXISTS example ( + id bigserial PRIMARY KEY, + name varchar(40) NOT NULL, + description varchar, + CONSTRAINT unique_name UNIQUE(name) + );"#, + ) + .await + .unwrap(); + let LogicalPlan::Extension(extension) = plan else { panic!("expect create table plan, get {plan}") }; + let plan = extension.node.as_any().downcast_ref::().unwrap(); + assert!(plan.if_not_exists); + assert_eq!(plan.table_ref.table(), "example"); + + let descriptor = &plan.table_descriptor; + assert_eq!(descriptor.columns[0].id, 1); + assert_eq!(descriptor.columns[0].name, "id"); + assert!(descriptor.columns[0].serial); + assert!(!descriptor.columns[0].nullable); + assert_eq!(descriptor.columns[0].type_kind, ColumnTypeKind::Int64); + + assert_eq!(descriptor.columns[1].id, 2); + assert_eq!(descriptor.columns[1].name, "name"); + assert!(!descriptor.columns[1].nullable); + assert_eq!(descriptor.columns[1].type_kind, ColumnTypeKind::String); + assert_eq!( + descriptor.columns[1].type_declaration, + 
Some(ColumnTypeDeclaration::Character(CharacterTypeDeclaration { max_length: 40 })) + ); + + assert_eq!(descriptor.columns[2].id, 3); + assert_eq!(descriptor.columns[2].name, "description"); + assert!(descriptor.columns[2].nullable); + assert_eq!(descriptor.columns[2].type_kind, ColumnTypeKind::String); + assert_eq!(descriptor.columns[2].type_declaration, None); + + assert_eq!(descriptor.indices[0].id, 1); + assert!(descriptor.indices[0].unique); + assert_eq!(descriptor.indices[0].column_ids, vec![1]); + assert_eq!(descriptor.indices[0].storing_column_ids, vec![2, 3]); + + assert_eq!(descriptor.indices[1].id, 2); + assert!(descriptor.indices[1].unique); + assert_eq!(descriptor.indices[1].column_ids, vec![2]); + assert_eq!(descriptor.indices[1].storing_column_ids, vec![1]); + } +} diff --git a/src/sql/shared.rs b/src/sql/shared.rs new file mode 100644 index 0000000..3223879 --- /dev/null +++ b/src/sql/shared.rs @@ -0,0 +1,182 @@ +// Copyright 2023 The SeamDB Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::{self, Display, Formatter}; + +use super::error::SqlError; +use super::traits::*; +use crate::protos::{ColumnDescriptor, ColumnTypeDeclaration, ColumnTypeKind, IndexDescriptor, TableDescriptor}; + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +enum IndexKind { + Index, + Unique, + Primary, +} + +impl Display for IndexKind { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Index => f.write_str("index"), + Self::Unique => f.write_str("unique index"), + Self::Primary => f.write_str("primary index"), + } + } +} + +#[derive(PartialEq, Eq, PartialOrd, Ord)] +struct IndexMeta { + kind: IndexKind, + name: Option, + column_names: Vec, + column_ids: Vec, +} + +impl IndexMeta { + fn new( + kind: IndexKind, + name: Option, + column_names: Vec, + descriptor: &TableDescriptor, + ) -> Result { + let column_ids = column_names + .iter() + .map(|column_name| match descriptor.find_column(column_name.as_ref()) { + None => Err(SqlError::invalid(format!( + "table {} has no column in defining {} {:?}", + descriptor.name, kind, name + ))), + Some(column) => Ok(column.id), + }) + .collect::, SqlError>>()?; + Ok(Self { kind, name, column_names, column_ids }) + } +} + +pub struct TableDescriptorBuilder { + descriptor: TableDescriptor, + primary_index: Option, + unique_indices: Vec, + indices: Vec, +} + +pub struct ColumnDescriptorBuilder<'a> { + builder: &'a mut TableDescriptorBuilder, +} + +impl ColumnDescriptorBuilder<'_> { + fn column(&mut self) -> &mut ColumnDescriptor { + self.builder.descriptor.columns.last_mut().unwrap() + } + + pub fn set_nullable(&mut self, nullable: bool) { + self.column().nullable = nullable; + } + + pub fn set_serial(&mut self, serial: bool) { + let column = self.column(); + column.nullable = column.nullable && !serial; + column.serial = serial; + } + + pub fn add_unique_constraint(&mut self, primary: bool, name: Option) -> Result<(), SqlError> { + let column_names = vec![self.column().name.clone()]; + match 
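+        // A column-level UNIQUE / PRIMARY KEY option becomes a table-level
+        // single-column index.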
primary {
+            true => self.builder.add_primary_index(name, column_names),
+            false => self.builder.add_unique_index(name, column_names),
+        }
+    }
+}
+
+impl TableDescriptorBuilder {
+    pub fn new() -> Self {
+        Self { descriptor: TableDescriptor::default(), primary_index: None, unique_indices: vec![], indices: vec![] }
+    }
+
+    pub fn add_column(
+        &mut self,
+        name: String,
+        type_kind: ColumnTypeKind,
+        type_declaration: Option<ColumnTypeDeclaration>,
+    ) -> Result<ColumnDescriptorBuilder<'_>, SqlError> {
+        if self.descriptor.find_column(&name).is_some() {
+            return Err(SqlError::invalid(format!("multiple columns named {}", name)));
+        }
+        self.descriptor.add_column(ColumnDescriptor {
+            id: 0,
+            name,
+            nullable: true,
+            type_kind,
+            type_declaration,
+            ..Default::default()
+        });
+        Ok(ColumnDescriptorBuilder { builder: self })
+    }
+
+    pub fn add_unique_index(&mut self, name: Option<String>, column_names: Vec<String>) -> Result<(), SqlError> {
+        let index_meta = IndexMeta::new(IndexKind::Unique, name, column_names, &self.descriptor)?;
+        self.unique_indices.push(index_meta);
+        Ok(())
+    }
+
+    pub fn add_primary_index(&mut self, name: Option<String>, column_names: Vec<String>) -> Result<(), SqlError> {
+        if let Some(index) = self.primary_index.as_ref() {
+            return Err(SqlError::invalid(format!(
+                "multiple primary indices: name {:?}, columns {:?} and name {:?}, columns {:?}",
+                index.name, index.column_names, name, column_names
+            )));
+        }
+        let index_meta = IndexMeta::new(IndexKind::Primary, name, column_names, &self.descriptor)?;
+        self.primary_index = Some(index_meta);
+        Ok(())
+    }
+
+    pub fn add_index(&mut self, name: Option<String>, column_names: Vec<String>) -> Result<(), SqlError> {
+        let index_meta = IndexMeta::new(IndexKind::Index, name, column_names, &self.descriptor)?;
+        self.indices.push(index_meta);
+        Ok(())
+    }
+
+    pub fn build(mut self) -> Result<TableDescriptor, SqlError> {
+        let Some(primary_index) = self.primary_index else {
+            return Err(SqlError::invalid(format!("table {} defines no primary index", self.descriptor.name)));
+        };
+        self.descriptor.add_index(IndexDescriptor {
+            id: 0,
+            name: primary_index.name.unwrap_or_default(),
+            column_ids: primary_index.column_ids,
+            unique: true,
+            storing_column_ids: vec![],
+        });
+        for unique_index in self.unique_indices.drain(0..) {
+            self.descriptor.add_index(IndexDescriptor {
+                id: 0,
+                name: unique_index.name.unwrap_or_default(),
+                column_ids: unique_index.column_ids,
+                unique: true,
+                storing_column_ids: vec![],
+            });
+        }
+        for index in self.indices.drain(0..) {
+            // Plain secondary indexes carry no uniqueness guarantee.
+            self.descriptor.add_index(IndexDescriptor {
+                id: 0,
+                name: index.name.unwrap_or_default(),
+                column_ids: index.column_ids,
+                unique: false,
+                storing_column_ids: vec![],
+            });
+        }
+        Ok(self.descriptor)
+    }
+}
diff --git a/src/sql/traits.rs b/src/sql/traits.rs
new file mode 100644
index 0000000..6bdc87b
--- /dev/null
+++ b/src/sql/traits.rs
@@ -0,0 +1,120 @@
+// Copyright 2023 The SeamDB Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
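+
+//! Planner-facing traits: `TableSchema` assigns column and index ids on
+//! `TableDescriptor`, while `PlannerContext` resolves and prefetches the
+//! table references named in a statement before planning.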
+ +use std::collections::{HashMap, HashSet}; +use std::ops::ControlFlow; +use std::sync::Arc; + +use datafusion::catalog::TableProvider; +use datafusion::catalog_common::ResolvedTableReference; +use datafusion::common::DataFusionError; +use datafusion::execution::session_state::SessionState; +use datafusion::sql::planner::object_name_to_table_reference; +use datafusion::sql::sqlparser::ast::{visit_relations, Statement}; +use datafusion::sql::TableReference; + +use super::error::SqlError; +use crate::protos::{ColumnDescriptor, IndexDescriptor, TableDescriptor}; + +pub trait TableSchema { + fn add_column(&mut self, column: ColumnDescriptor) -> &mut ColumnDescriptor; + fn add_index(&mut self, index: IndexDescriptor); +} + +impl TableSchema for TableDescriptor { + fn add_column(&mut self, mut column: ColumnDescriptor) -> &mut ColumnDescriptor { + column.id = self.last_column_id + 1; + self.last_column_id = column.id; + self.columns.push(column); + self.columns.last_mut().unwrap() + } + + fn add_index(&mut self, mut index: IndexDescriptor) { + index.id = self.last_index_id + 1; + if index.storing_column_ids.is_empty() { + match self.indices.first() { + None => { + assert!(index.unique); + // This is the primary index. + index.storing_column_ids.extend( + self.columns + .iter() + .map(|column| column.id) + .filter(|column_id| !index.column_ids.iter().copied().any(|id| id == *column_id)), + ); + }, + Some(primary_index) => index.storing_column_ids.extend_from_slice(&primary_index.column_ids), + } + } + if index.name.is_empty() { + index.name = match (self.indices.len(), index.unique) { + (0, true) => format!("primary_index_{}", index.id), + (_, true) => format!("unique_index_{}", index.id), + (_, _) => format!("index_{}", index.id), + }; + } + self.last_index_id = index.id; + self.indices.push(index); + } +} + +#[async_trait::async_trait] +pub trait PlannerContext { + fn state(&self) -> &SessionState; + + fn collect_table_references(&self, statement: &Statement) -> Result, DataFusionError> { + let mut relations = HashSet::new(); + visit_relations(statement, |relation| { + relations.insert(relation.clone()); + ControlFlow::<()>::Continue(()) + }); + relations + .into_iter() + .map(|x| object_name_to_table_reference(x, true).map(|x| self.resolve_table_reference(x))) + .collect::>() + } + + fn resolve_table_reference(&self, name: TableReference) -> ResolvedTableReference { + let default_catalog = &self.state().config_options().catalog.default_catalog; + name.resolve(default_catalog, "public") + } + + async fn fetch_table_references( + &self, + table_references: Vec, + ) -> Result>, SqlError> { + let catalogs = self.state().catalog_list(); + let mut tables = HashMap::new(); + for table_ref in table_references { + let Some(catalog) = catalogs.catalog(&table_ref.catalog) else { + continue; + }; + let Some(schema) = catalog.schema(&table_ref.schema) else { + continue; + }; + let Some(table) = schema.table(&table_ref.table).await? 
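+                // Unresolvable references are skipped here; planning surfaces
+                // a "no table source" error for them later.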
else { + continue; + }; + tables.insert(table_ref.into(), table); + } + Ok(tables) + } +} + +#[async_trait::async_trait] +impl PlannerContext for SessionState { + fn state(&self) -> &SessionState { + self + } +} diff --git a/src/tablet/client.rs b/src/tablet/client.rs index 2f4676a..1c63c23 100644 --- a/src/tablet/client.rs +++ b/src/tablet/client.rs @@ -17,26 +17,35 @@ use std::sync::Arc; use std::time::Duration; use anyhow::anyhow; +use enum_dispatch::enum_dispatch; use prost::Message as _; use thiserror::Error; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tonic::Status; -use tracing::debug; +use tracing::{debug, trace}; use crate::cluster::{ClusterEnv, NodeId}; use crate::keys; use crate::protos::{ + BatchError, BatchRequest, + BatchResponse, + DataError, DataRequest, DataResponse, FindRequest, FindResponse, GetRequest, + GetResponse, IncrementRequest, + KeyRange, + KeySpan, ParticipateTxnRequest, ParticipateTxnResponse, PutRequest, + RefreshReadRequest, + ScanRequest, ShardDescriptor, ShardId, ShardRequest, @@ -46,7 +55,7 @@ use crate::protos::{ TabletServiceClient, Temporal, Timestamp, - TimestampedValue, + TimestampedKeyValue, Transaction, TxnMeta, Uuid, @@ -54,14 +63,299 @@ use crate::protos::{ }; // TODO: cache and invalidate on error -#[derive(Clone)] -pub struct TabletClient { +struct RootTabletClient { root: Arc, descriptor: Arc, deployment: Arc, cluster: ClusterEnv, } +#[derive(Clone)] +struct ScopedTabletClient { + prefix: Vec, + client: Arc, +} + +#[enum_dispatch] +#[derive(Clone)] +enum InnerTabletClient { + Root(Arc), + Scoped(ScopedTabletClient), +} + +#[enum_dispatch(InnerTabletClient)] +trait TabletClientTrait { + fn now(&self) -> Timestamp; + + fn prefix(&self) -> &[u8]; + + fn root(&self) -> &RootTabletClient; + + #[allow(clippy::uninit_vec)] + fn prefix_key<'a>(&self, key: impl Into>) -> Vec { + let key = key.into(); + let prefix = self.prefix(); + if prefix.is_empty() { + return key.into_owned(); + } + match key { + Cow::Borrowed(key) => { + let mut prefixed_key = Vec::with_capacity(prefix.len() + key.len()); + prefixed_key.extend_from_slice(prefix); + prefixed_key.extend_from_slice(key); + prefixed_key + }, + Cow::Owned(mut key) => { + let key_len = key.len(); + let prefix_len = prefix.len(); + key.reserve(prefix_len); + unsafe { + key.set_len(key_len + prefix_len); + std::ptr::copy(key.as_ptr(), key.as_mut_ptr().wrapping_add(prefix_len), key_len); + std::ptr::copy_nonoverlapping(prefix.as_ptr(), key.as_mut_ptr(), prefix.len()); + } + key + }, + } + } + + fn prefix_span(&self, span: KeySpan) -> KeySpan { + let key = self.prefix_key(span.key); + let end = if span.end.is_empty() { span.end } else { self.prefix_key(span.end) }; + KeySpan { key, end } + } + + fn prefix_range(&self, range: KeyRange) -> KeyRange { + let start = self.prefix_key(range.start); + let end = self.prefix_key(range.end); + KeyRange { start, end } + } + + fn unprefix_key(&self, mut key: Vec) -> Result, Vec> { + let Some(stripped) = key.strip_prefix(self.prefix()) else { + return Err(key); + }; + let (ptr, len) = (stripped.as_ptr(), stripped.len()); + unsafe { + std::ptr::copy(ptr, key.as_mut_ptr(), len); + key.set_len(len); + } + Ok(key) + } + + fn new_transaction(&self, key: Vec) -> Transaction { + let key = self.prefix_key(key); + let meta = TxnMeta { id: Uuid::new_random(), key, epoch: 0, start_ts: self.now(), priority: 0 }; + Transaction { meta, ..Default::default() } + } + + async fn heartbeat_txn(&self, txn: Transaction) -> Result { + 
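+        // Keys in `txn` were already prefixed by `new_transaction`, so the
+        // heartbeat delegates to the root client unchanged.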
self.root().heartbeat_txn(txn).await + } + + async fn get_directly( + &self, + temporal: Temporal, + key: impl Into>, + sequence: u32, + ) -> Result<(Temporal, Option)> { + let key = self.prefix_key(key); + self.root().get_directly(temporal, key, sequence).await + } + + async fn get(&self, key: impl Into>) -> Result> { + match self.get_directly(Temporal::default(), key, 0).await? { + (_temporal, None) => Err(TabletClientError::unexpected("timestamp get get no response")), + (_temporal, Some(response)) => Ok(response.value.map(|v| v.into_parts())), + } + } + + async fn put_directly( + &self, + temporal: Temporal, + key: impl Into>, + value: Option, + expect_ts: Option, + sequence: u32, + ) -> Result<(Temporal, Option)> { + let key = self.prefix_key(key); + self.root().put_directly(temporal, key, value, expect_ts, sequence).await + } + + async fn put( + &self, + key: impl Into>, + value: Option, + expect_ts: Option, + ) -> Result { + match self.put_directly(Temporal::default(), key, value, expect_ts, 0).await? { + (_temporal, None) => Err(TabletClientError::unexpected("timestamp put get aborted")), + (_temporal, Some(write_ts)) => Ok(write_ts), + } + } + + async fn refresh_read( + &self, + temporal: Temporal, + span: KeySpan, + from: Timestamp, + ) -> Result<(Temporal, Option>)> { + let span = self.prefix_span(span); + let (temporal, resume_key) = self.root().refresh_read(temporal, span, from).await?; + let resume_key = match resume_key { + Some(resume_key) if !resume_key.is_empty() => Some(self.unprefix_key(resume_key).unwrap()), + _ => resume_key, + }; + Ok((temporal, resume_key)) + } + + async fn increment_directly( + &self, + temporal: Temporal, + key: impl Into>, + increment: i64, + sequence: u32, + ) -> Result<(Temporal, Option)> { + let key = self.prefix_key(key); + self.root().increment_directly(temporal, key, increment, sequence).await + } + + async fn increment(&self, key: impl Into>, increment: i64) -> Result { + match self.increment_directly(Temporal::default(), key, increment, 0).await? { + (_temporal, None) => Err(TabletClientError::unexpected("timestamp increment get aborted")), + (_temporal, Some(incremented)) => Ok(incremented), + } + } + + async fn scan_directly( + &self, + temporal: Temporal, + start: impl Into>, + end: impl Into>, + limit: u32, + ) -> Result<(Temporal, Option<(Vec, Vec)>)> { + let start = self.prefix_key(start); + let end = self.prefix_key(end); + match self.root().scan_directly(temporal, start, end, limit).await? { + (temporal, None) => Ok((temporal, None)), + (temporal, Some((resume_key, mut rows))) => { + let resume_key = + if resume_key.is_empty() { resume_key } else { self.unprefix_key(resume_key).unwrap() }; + for row in rows.iter_mut() { + row.key = self.unprefix_key(std::mem::take(&mut row.key)).unwrap(); + } + Ok((temporal, Some((resume_key, rows)))) + }, + } + } + + async fn scan( + &self, + start: impl Into>, + end: impl Into>, + limit: u32, + ) -> Result<(Vec, Vec)> { + match self.scan_directly(Temporal::default(), start, end, limit).await? 
{ + (_temporal, None) => Err(TabletClientError::unexpected("no scan response")), + (_temporal, Some(response)) => Ok(response), + } + } + + async fn locate(&self, key: impl Into>) -> Result { + let key = self.prefix_key(key); + self.root().locate(key).await + } + + async fn participate_txn( + &self, + request: ParticipateTxnRequest, + coordinator: bool, + ) -> Result<(mpsc::Sender, tonic::Streaming)> { + self.root().participate_txn(request, coordinator).await + } + + async fn open_participate_txn( + &self, + request: ParticipateTxnRequest, + coordinator: bool, + ) -> (mpsc::Sender, tonic::Streaming) { + self.root().open_participate_txn(request, coordinator).await + } + + async fn service(&self, key: &[u8]) -> Result<(ShardDeployment, TabletServiceClient)> { + let key = self.prefix_key(key); + self.root().service(&key).await + } + + async fn batch(&self, mut requests: Vec) -> Result> { + if self.prefix().is_empty() { + return self.root().batch(requests).await; + } + for request in requests.iter_mut() { + match request { + DataRequest::Get(get) => { + get.key = self.prefix_key(Cow::Owned(std::mem::take(&mut get.key))); + }, + DataRequest::Put(put) => { + put.key = self.prefix_key(Cow::Owned(std::mem::take(&mut put.key))); + }, + DataRequest::Increment(increment) => { + increment.key = self.prefix_key(Cow::Owned(std::mem::take(&mut increment.key))); + }, + DataRequest::Find(find) => { + find.key = self.prefix_key(Cow::Owned(std::mem::take(&mut find.key))); + }, + DataRequest::Scan(scan) => { + scan.range = self.prefix_range(std::mem::take(&mut scan.range)); + }, + DataRequest::RefreshRead(refresh_read) => { + refresh_read.span = self.prefix_span(std::mem::take(&mut refresh_read.span)); + }, + } + } + let mut responses = self.root().batch(requests).await?; + for response in responses.iter_mut() { + match response { + DataResponse::Get(_) => {}, + DataResponse::Put(_) => {}, + DataResponse::Increment(_) => {}, + DataResponse::Find(find) => { + if !find.key.is_empty() { + find.key = self.unprefix_key(std::mem::take(&mut find.key)).unwrap(); + } + }, + DataResponse::Scan(scan) => { + if !scan.resume_key.is_empty() { + scan.resume_key = self.unprefix_key(std::mem::take(&mut scan.resume_key)).unwrap(); + } + for row in scan.rows.iter_mut() { + row.key = self.unprefix_key(std::mem::take(&mut row.key)).unwrap(); + } + }, + DataResponse::RefreshRead(refresh_read) => { + if !refresh_read.resume_key.is_empty() { + refresh_read.resume_key = + self.unprefix_key(std::mem::take(&mut refresh_read.resume_key)).unwrap(); + } + }, + } + } + Ok(responses) + } + + async fn find(&self, key: &[u8]) -> Result, Value)>> { + let key = self.prefix_key(key); + let Some((ts, key, value)) = self.root().find(key).await? 
else {
+            return Ok(None);
+        };
+        Ok(Some((ts, self.unprefix_key(key).unwrap(), value)))
+    }
+
+    async fn get_tablet_descriptor(&self, id: TabletId) -> Result<(Timestamp, TabletDescriptor)> {
+        self.root().get_tablet_descriptor(id).await
+    }
+}
+
 #[derive(Debug, Error)]
 pub enum TabletClientError {
     #[error("cluster not ready")]
@@ -84,6 +378,10 @@ pub enum TabletClientError {
     ShardNotFound { tablet_id: TabletId, shard_id: ShardId, key: Vec<u8> },
     #[error("data corruption: {message}")]
     DataCorruption { message: String },
+    #[error("key {key:?} already exists")]
+    KeyAlreadyExists { key: Vec<u8> },
+    #[error("key {key:?} was overwritten at {actual}")]
+    KeyTimestampMismatch { key: Vec<u8>, actual: Timestamp },
     #[error("invalid argument: {message}")]
     InvalidArgument { message: String },
     #[error(transparent)]
@@ -113,6 +411,8 @@ impl From<TabletClientError> for tonic::Status {
                 Status::internal(err.to_string())
             },
             TabletClientError::UnexpectedError { .. } => Status::unknown(err.to_string()),
+            TabletClientError::KeyAlreadyExists { .. } => Status::already_exists(err.to_string()),
+            TabletClientError::KeyTimestampMismatch { .. } => Status::failed_precondition(err.to_string()),
         }
     }
 }
@@ -169,7 +468,35 @@ impl ShardDeployment {
     }
 }
 
-impl TabletClient {
+impl TabletClientTrait for RootTabletClient {
+    fn now(&self) -> Timestamp {
+        self.cluster.clock().now()
+    }
+
+    fn prefix(&self) -> &[u8] {
+        Default::default()
+    }
+
+    fn root(&self) -> &RootTabletClient {
+        self
+    }
+}
+
+impl TabletClientTrait for Arc<RootTabletClient> {
+    fn now(&self) -> Timestamp {
+        self.cluster.clock().now()
+    }
+
+    fn prefix(&self) -> &[u8] {
+        Default::default()
+    }
+
+    fn root(&self) -> &RootTabletClient {
+        self.as_ref()
+    }
+}
+
+impl RootTabletClient {
     pub fn new(cluster: ClusterEnv) -> Self {
         Self {
             cluster,
@@ -179,15 +506,6 @@ impl TabletClient {
         }
     }
 
-    pub fn now(&self) -> Timestamp {
-        self.cluster.clock().now()
-    }
-
-    pub fn new_transaction(&self, key: Vec<u8>) -> Transaction {
-        let meta = TxnMeta { id: Uuid::new_random(), key, epoch: 0, start_ts: self.cluster.clock().now(), priority: 0 };
-        Transaction { meta, ..Default::default() }
-    }
-
     fn get_cluster_descriptor(&self) -> Result<(Timestamp, TabletDescriptor)> {
         let Some(descriptor) = self.cluster.latest_descriptor() else {
             return Err(TabletClientError::ClusterNotReady);
@@ -266,34 +584,26 @@ impl TabletClient {
     }
 
     async fn get_shard(&self, deployment: &ShardDeployment, key: impl Into<Vec<u8>>) -> Result<Arc<ShardDescriptor>> {
-        let Some(node) = deployment.node_id() else {
-            return Err(TabletClientError::TabletNotDeployed { id: deployment.tablet_id() });
-        };
-        let Some(addr) = self.cluster.nodes().get_endpoint(node) else {
-            return Err(TabletClientError::NodeNotAvailable { node: node.clone() });
-        };
-        let mut client = TabletServiceClient::connect(addr.to_string())
-            .await
-            .map_err(|e| TabletClientError::NodeNotConnectable { node: node.clone(), message: e.to_string() })?;
         let key = key.into();
-        let batch = BatchRequest {
-            tablet_id: deployment.tablet_id().into(),
-            uncertainty: None,
-            temporal: Temporal::default(),
-            requests: vec![ShardRequest {
-                shard_id: deployment.shard_id().into(),
-                request: DataRequest::Find(FindRequest { key, sequence: 0 }),
-            }],
-        };
-        let response = client.batch(batch).await?.into_inner();
-        let find = response
+        let find = self
+            .batch_request(deployment, BatchRequest {
+                tablet_id: deployment.tablet_id().into(),
+                uncertainty: None,
+                temporal: Temporal::default(),
+                requests: vec![ShardRequest {
+                    shard_id: deployment.shard_id().into(),
+                    request: DataRequest::Find(FindRequest { key: key.clone(), sequence: 0 }),
+                }],
+            })
+            .await?
.into_find() .map_err(|r| TabletClientError::unexpected(format!("unexpected find response: {:?}", r)))?; + let FindResponse { key: located_key, value: Some(value) } = find else { return Err(TabletClientError::ShardNotFound { tablet_id: deployment.tablet_id(), shard_id: deployment.shard_id(), - key: find.key, + key, }); }; let bytes = value @@ -350,9 +660,9 @@ impl TabletClient { let Some(addr) = self.cluster.nodes().get_endpoint(node) else { return Err(TabletClientError::NodeNotAvailable { node: node.clone() }); }; - let service = - TabletServiceClient::connect(addr.to_string()).await.map_err(|e| Status::unavailable(e.to_string()))?; - Ok(service) + TabletServiceClient::connect(addr.to_string()) + .await + .map_err(|e| TabletClientError::NodeNotConnectable { node: node.clone(), message: e.to_string() }) } pub async fn service( @@ -364,21 +674,33 @@ impl TabletClient { Ok((deployment, service)) } - async fn request(&self, deployment: &ShardDeployment, request: DataRequest) -> Result { + async fn batch_request(&self, deployment: &ShardDeployment, request: BatchRequest) -> Result { + let tablet_id = request.tablet_id; let mut service = self.connect(deployment).await?; - let tablet_id = deployment.tablet_id().into(); - let shard_id = deployment.shard_id().into(); - let batch = BatchRequest { - tablet_id, - uncertainty: None, - temporal: Temporal::default(), - requests: vec![ShardRequest { shard_id, request }], - }; - let response = service.batch(batch).await?.into_inner(); - let response = response - .into_one() - .map_err(|r| TabletClientError::unexpected(format!("expect one response, got {:?}", r)))?; - Ok(response.response) + match service.batch(request).await { + Ok(response) => Ok(response.into_inner()), + Err(status) => { + let details = status.details(); + if !details.is_empty() { + let err = BatchError::decode(details).unwrap(); + trace!("batch error: {err:?}"); + let err = match err.error { + DataError::ConflictWrite(err) => TabletClientError::KeyAlreadyExists { key: err.key }, + DataError::ShardNotFound(err) => TabletClientError::ShardNotFound { + tablet_id: tablet_id.into(), + shard_id: err.shard_id.into(), + key: err.key, + }, + DataError::TimestampMismatch(err) => { + TabletClientError::KeyTimestampMismatch { key: err.key, actual: err.actual } + }, + err => TabletClientError::Internal(anyhow!("{err}")), + }; + return Err(err); + } + Err(TabletClientError::from(status)) + }, + } } async fn request_batch( @@ -386,14 +708,6 @@ impl TabletClient { deployment: &ShardDeployment, requests: Vec, ) -> Result> { - let Some(node) = deployment.node_id() else { - return Err(TabletClientError::TabletNotDeployed { id: deployment.tablet_id() }); - }; - let Some(addr) = self.cluster.nodes().get_endpoint(node) else { - return Err(TabletClientError::NodeNotAvailable { node: node.clone() }); - }; - let mut client = - TabletServiceClient::connect(addr.to_string()).await.map_err(|e| Status::unavailable(e.to_string()))?; let n = requests.len(); let batch = BatchRequest { tablet_id: deployment.tablet.id, @@ -401,7 +715,7 @@ impl TabletClient { temporal: Temporal::default(), requests, }; - let response = client.batch(batch).await?.into_inner(); + let response = self.batch_request(deployment, batch).await?; if response.responses.len() != n { return Err(TabletClientError::unexpected(format!("unexpected responses: {:?}", response))); } @@ -429,152 +743,217 @@ impl TabletClient { self.request_batch(&deployment, requests).await } + async fn get_internally( + &self, + deployment: &ShardDeployment, + temporal: 
Temporal,
+        key: Vec<u8>,
+        sequence: u32,
+    ) -> Result<(Temporal, GetResponse)> {
+        let mut response = self
+            .batch_request(deployment, BatchRequest {
+                tablet_id: deployment.tablet_id().into(),
+                uncertainty: None,
+                temporal,
+                requests: vec![ShardRequest {
+                    shard_id: deployment.shard_id().into(),
+                    request: DataRequest::Get(GetRequest { key, sequence }),
+                }],
+            })
+            .await?;
+        let get_response = response
+            .responses
+            .remove(0)
+            .response
+            .into_get()
+            .map_err(|r| TabletClientError::unexpected(format!("expect get response, got {r:?}")))?;
+        Ok((response.temporal, get_response))
+    }
+
     async fn raw_get(
         &self,
         deployment: &ShardDeployment,
         key: impl Into<Vec<u8>>,
     ) -> Result<Option<(Timestamp, Value)>> {
-        let get = GetRequest { key: key.into(), sequence: 0 };
-        let response = self.request(deployment, DataRequest::Get(get)).await?;
-        let response = response.into_get().map_err(|r| anyhow!("expect get response, get {:?}", r))?;
+        let (_, response) = self.get_internally(deployment, Temporal::default(), key.into(), 0).await?;
         Ok(response.value.map(|v| (v.timestamp, v.value)))
     }
 
-    pub async fn get(&self, key: impl Into<Cow<'_, [u8]>>) -> Result<Option<(Timestamp, Value)>> {
-        let key = key.into().into_owned();
-        let deployment = self.locate(&key).await?;
-        let get = GetRequest { key, sequence: 0 };
-        let response = self.request(&deployment, DataRequest::Get(get)).await?;
-        let response = response.into_get().map_err(|r| anyhow!("expect get response, get {:?}", r))?;
-        Ok(response.value.map(|v| (v.timestamp, v.value)))
+    pub async fn heartbeat_txn(&self, txn: Transaction) -> Result<Transaction> {
+        let (shard, mut service) = self.service(&txn.meta.key).await?;
+        let response = service
+            .batch(BatchRequest {
+                tablet_id: shard.tablet_id().into(),
+                temporal: Temporal::Transaction(txn),
+                ..Default::default()
+            })
+            .await?
+            .into_inner();
+        Ok(response.temporal.into_transaction())
     }
 
-    async fn put_internally(
+    pub async fn get_directly(
         &self,
-        key: Cow<'_, [u8]>,
-        value: Option<Value>,
-        expect_ts: Option<Timestamp>,
-    ) -> Result<Timestamp> {
-        let key = key.into_owned();
-        let deployment = self.locate(&key).await?;
-        let put = PutRequest { key, value, sequence: 0, expect_ts };
-        let response = self.request(&deployment, DataRequest::Put(put)).await?;
-        let response = response.into_put().map_err(|r| anyhow!("expect put response, get {:?}", r))?;
-        Ok(response.write_ts)
+        temporal: Temporal,
+        key: impl Into<Cow<'_, [u8]>>,
+        sequence: u32,
+    ) -> Result<(Temporal, Option<GetResponse>)> {
+        let key = key.into().into_owned();
+        let shard = self.locate(&key).await?;
+        self.get_internally(&shard, temporal, key, sequence)
+            .await
+            .map(|(temporal, response)| (temporal, Some(response)))
     }
 
-    pub async fn delete(&self, key: impl Into<Cow<'_, [u8]>>, expect_ts: Option<Timestamp>) -> Result<()> {
-        self.put_internally(key.into(), None, expect_ts).await?;
-        Ok(())
+    async fn find(&self, key: Vec<u8>) -> Result<Option<(Timestamp, Vec<u8>, Value)>> {
+        let shard = self.locate(&key).await?;
+        let batch = BatchRequest {
+            tablet_id: shard.tablet_id().into(),
+            uncertainty: None,
+            temporal: Temporal::default(),
+            requests: vec![ShardRequest {
+                shard_id: shard.shard_id().into(),
+                request: DataRequest::Find(FindRequest { key, sequence: 0 }),
+            }],
+        };
+        let mut response = self.batch_request(&shard, batch).await?;
+        if response.responses.is_empty() {
+            return Err(TabletClientError::unexpected("find got no response"));
+        }
+        let find_response = response
+            .responses
+            .remove(0)
+            .response
+            .into_find()
+            .map_err(|r| TabletClientError::unexpected(format!("expect find response, got {r:?}")))?;
+        Ok(find_response.value.map(|v| (v.timestamp, find_response.key, v.value)))
     }
 
-    pub async fn put(
+ pub async fn put_directly( &self, + temporal: Temporal, key: impl Into>, - value: Value, + value: Option, expect_ts: Option, - ) -> Result { - self.put_internally(key.into(), Some(value), expect_ts).await - } - - pub async fn increment(&self, key: impl Into>, increment: i64) -> Result { - let key = key.into(); - let deployment = self.locate(key.as_ref()).await?; - let increment = IncrementRequest { key: key.to_vec(), increment, sequence: 0 }; - let response = self.request(&deployment, DataRequest::Increment(increment)).await?; - let response = response.into_increment().map_err(|r| anyhow!("expect increment response, get {:?}", r))?; - Ok(response.value) - } - - pub async fn find(&self, key: &[u8]) -> Result, Value)>> { - let user_key = keys::user_key(key); - let deployment = self.locate(&user_key).await?; - let find = FindRequest { key: user_key, sequence: 0 }; - let response = self.request(&deployment, DataRequest::Find(find)).await?; - let response = response.into_find().map_err(|r| anyhow!("expect find response, get {:?}", r))?; - match response.value { - None => Ok(None), - Some(value) => { - let mut key = response.key; - key.drain(0..keys::USER_KEY_PREFIX.len()); - Ok(Some((value.timestamp, key, value.value))) - }, + sequence: u32, + ) -> Result<(Temporal, Option)> { + let key = key.into().into_owned(); + let shard = self.locate(&key).await?; + let batch = BatchRequest { + tablet_id: shard.tablet_id().into(), + uncertainty: None, + temporal, + requests: vec![ShardRequest { + shard_id: shard.shard_id().into(), + request: DataRequest::Put(PutRequest { key, value, expect_ts, sequence }), + }], + }; + let mut response = self.batch_request(&shard, batch).await?; + if response.responses.is_empty() { + return Ok((response.temporal, None)); } + let put_response = response + .responses + .remove(0) + .response + .into_put() + .map_err(|r| TabletClientError::unexpected(format!("expect put response, got {r:?}")))?; + Ok((response.temporal, Some(put_response.write_ts))) } - pub async fn transactional_get( + pub async fn refresh_read( &self, - txn: Transaction, - key: &[u8], - sequence: u32, - ) -> Result<(Transaction, Option)> { - let (shard, mut service) = self.service(key).await?; - let mut response = service - .batch(BatchRequest { - tablet_id: shard.tablet_id().into(), - temporal: Temporal::Transaction(txn), - requests: vec![ShardRequest { - shard_id: shard.shard_id().into(), - request: DataRequest::Get(GetRequest { key: key.to_owned(), sequence }), - }], - ..Default::default() - }) - .await? 
-            .into_inner();
-        let txn = std::mem::take(&mut response.temporal).into_transaction();
-        let get = response.into_get().map_err(|r| anyhow!("expect get response, get {r:?}"))?;
-        Ok((txn, get.value))
+        temporal: Temporal,
+        span: KeySpan,
+        from: Timestamp,
+    ) -> Result<(Temporal, Option<Vec<u8>>)> {
+        let shard = self.locate(&span.key).await?;
+        let batch = BatchRequest {
+            tablet_id: shard.tablet_id().into(),
+            uncertainty: None,
+            temporal,
+            requests: vec![ShardRequest {
+                shard_id: shard.shard_id().into(),
+                request: DataRequest::RefreshRead(RefreshReadRequest { span, from }),
+            }],
+        };
+        let mut response = self.batch_request(&shard, batch).await?;
+        if response.responses.is_empty() {
+            return Ok((response.temporal, None));
+        }
+        let refresh_read_response = response
+            .responses
+            .remove(0)
+            .response
+            .into_refresh_read()
+            .map_err(|r| TabletClientError::unexpected(format!("expect refresh read response, got {r:?}")))?;
+        Ok((response.temporal, Some(refresh_read_response.resume_key)))
     }
 
-    pub async fn transactional_put(
+    pub async fn increment_directly(
         &self,
-        txn: Transaction,
-        key: &[u8],
-        value: Option<Value>,
+        temporal: Temporal,
+        key: impl Into<Cow<'_, [u8]>>,
+        increment: i64,
         sequence: u32,
-        expect_ts: Option<Timestamp>,
-    ) -> Result<Transaction> {
-        let (shard, mut service) = self.service(key).await?;
-        let mut response = service
-            .batch(BatchRequest {
-                tablet_id: shard.tablet_id().into(),
-                temporal: Temporal::Transaction(txn),
-                requests: vec![ShardRequest {
-                    shard_id: shard.shard_id().into(),
-                    request: DataRequest::Put(PutRequest { key: key.to_owned(), value, sequence, expect_ts }),
-                }],
-                ..Default::default()
-            })
-            .await?
-            .into_inner();
-        let txn = std::mem::take(&mut response.temporal).into_transaction();
-        response.into_put().map_err(|r| anyhow!("expect put response, get {r:?}"))?;
-        Ok(txn)
+    ) -> Result<(Temporal, Option<i64>)> {
+        let key = key.into().into_owned();
+        let shard = self.locate(&key).await?;
+        let batch = BatchRequest {
+            tablet_id: shard.tablet_id().into(),
+            uncertainty: None,
+            temporal,
+            requests: vec![ShardRequest {
+                shard_id: shard.shard_id().into(),
+                request: DataRequest::Increment(IncrementRequest { key, increment, sequence }),
+            }],
+        };
+        let mut response = self.batch_request(&shard, batch).await?;
+        if response.responses.is_empty() {
+            return Ok((response.temporal, None));
+        }
+        let increment_response = response
+            .responses
+            .remove(0)
+            .response
+            .into_increment()
+            .map_err(|r| TabletClientError::unexpected(format!("expect increment response, got {r:?}")))?;
+        Ok((response.temporal, Some(increment_response.value)))
    }

-    pub async fn transactional_increment(
+    pub async fn scan_directly(
         &self,
-        txn: Transaction,
-        key: &[u8],
-        increment: i64,
-        sequence: u32,
-    ) -> Result<(Transaction, i64)> {
-        let (shard, mut service) = self.service(key).await?;
-        let mut response = service
-            .batch(BatchRequest {
-                tablet_id: shard.tablet_id().into(),
-                temporal: Temporal::Transaction(txn),
-                requests: vec![ShardRequest {
-                    shard_id: shard.shard_id().into(),
-                    request: DataRequest::Increment(IncrementRequest { key: key.to_owned(), increment, sequence }),
-                }],
-                ..Default::default()
-            })
-            .await?
- .into_inner(); - let txn = std::mem::take(&mut response.temporal).into_transaction(); - let increment = response.into_increment().map_err(|r| anyhow!("expect increment response, get {r:?}"))?; - Ok((txn, increment.value)) + temporal: Temporal, + start: impl Into>, + end: impl Into>, + limit: u32, + ) -> Result<(Temporal, Option<(Vec, Vec)>)> { + let start = start.into().into_owned(); + let shard = self.locate(&start).await?; + let batch = BatchRequest { + tablet_id: shard.tablet_id().into(), + uncertainty: None, + temporal, + requests: vec![ShardRequest { + shard_id: shard.shard_id().into(), + request: DataRequest::Scan(ScanRequest { + range: KeyRange { start, end: end.into().into_owned() }, + limit, + sequence: 0, + }), + }], + }; + let mut response = self.batch_request(&shard, batch).await?; + if response.responses.is_empty() { + return Ok((response.temporal, None)); + } + let scan_response = response + .responses + .remove(0) + .response + .into_scan() + .map_err(|r| TabletClientError::unexpected(format!("expect scan response, got {r:?}")))?; + Ok((response.temporal, Some(scan_response.into_parts()))) } pub async fn open_participate_txn( @@ -612,6 +991,200 @@ impl TabletClient { } } +impl TabletClientTrait for ScopedTabletClient { + fn now(&self) -> Timestamp { + self.client.now() + } + + fn prefix(&self) -> &[u8] { + &self.prefix + } + + fn root(&self) -> &RootTabletClient { + &self.client + } +} + +impl ScopedTabletClient { + pub fn scope(mut self, prefix: &[u8]) -> Self { + self.prefix.extend_from_slice(prefix); + self + } +} + +impl From for TabletClient { + fn from(client: ScopedTabletClient) -> Self { + Self { inner: InnerTabletClient::Scoped(client) } + } +} + +impl From> for ScopedTabletClient { + fn from(client: Arc) -> Self { + ScopedTabletClient { prefix: Default::default(), client } + } +} + +#[derive(Clone)] +pub struct TabletClient { + inner: InnerTabletClient, +} + +impl TabletClient { + pub fn new(cluster: ClusterEnv) -> Self { + let client = RootTabletClient::new(cluster); + Self { inner: InnerTabletClient::Root(client.into()) } + } + + fn scoped_client(self) -> ScopedTabletClient { + match self.inner { + InnerTabletClient::Root(client) => ScopedTabletClient { prefix: vec![], client }, + InnerTabletClient::Scoped(client) => client, + } + } + + pub fn scope(self, prefix: &[u8]) -> Self { + self.scoped_client().scope(prefix).into() + } + + pub fn now(&self) -> Timestamp { + self.inner.now() + } + + pub fn prefix(&self) -> &[u8] { + self.inner.prefix() + } + + pub fn prefix_key<'a>(&self, key: impl Into>) -> Vec { + self.inner.prefix_key(key) + } + + pub fn unprefix_key(&self, key: Vec) -> Result, Vec> { + self.inner.unprefix_key(key) + } + + pub fn new_transaction(&self, key: Vec) -> Transaction { + self.inner.new_transaction(key) + } + + pub async fn heartbeat_txn(&self, txn: Transaction) -> Result { + self.inner.heartbeat_txn(txn).await + } + + pub async fn get_directly( + &self, + temporal: Temporal, + key: impl Into>, + sequence: u32, + ) -> Result<(Temporal, Option)> { + self.inner.get_directly(temporal, key, sequence).await + } + + pub async fn put_directly( + &self, + temporal: Temporal, + key: impl Into>, + value: Option, + expect_ts: Option, + sequence: u32, + ) -> Result<(Temporal, Option)> { + self.inner.put_directly(temporal, key, value, expect_ts, sequence).await + } + + pub async fn refresh_read( + &self, + temporal: Temporal, + span: KeySpan, + from: Timestamp, + ) -> Result<(Temporal, Option>)> { + self.inner.refresh_read(temporal, span, from).await + } 
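Taken together, `TabletClient` is now a thin facade: `scope()` layers a key prefix on top of the root client, and every response strips that prefix back off before results surface. A rough usage sketch, modeled on the tests further down in this diff (bootstrap and error handling elided; the `demo` function itself is illustrative, not part of the change):

```rust
use seamdb::cluster::ClusterEnv;
use seamdb::keys;
use seamdb::protos::Value;
use seamdb::tablet::TabletClient;

async fn demo(cluster_env: ClusterEnv) -> anyhow::Result<()> {
    // Every key below is transparently prefixed with USER_KEY_PREFIX on the
    // way in and unprefixed on the way out, exactly as the tests exercise it.
    let client = TabletClient::new(cluster_env).scope(keys::USER_KEY_PREFIX);

    let put_ts = client.put(b"k1", Some(Value::Bytes(b"v1".to_vec())), None).await?;
    let (get_ts, _value) = client.get(b"k1").await?.expect("value written above");
    assert_eq!(get_ts, put_ts);

    // A delete is a put of None, optionally guarded by the last write timestamp.
    client.put(b"k1", None, Some(put_ts)).await?;
    Ok(())
}
```

Scopes compose: calling `scope()` on an already-scoped client appends to the existing prefix rather than replacing it, per `ScopedTabletClient::scope` above.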
+ + pub async fn increment_directly( + &self, + temporal: Temporal, + key: impl Into>, + increment: i64, + sequence: u32, + ) -> Result<(Temporal, Option)> { + self.inner.increment_directly(temporal, key, increment, sequence).await + } + + pub async fn scan_directly( + &self, + temporal: Temporal, + start: impl Into>, + end: impl Into>, + limit: u32, + ) -> Result<(Temporal, Option<(Vec, Vec)>)> { + self.inner.scan_directly(temporal, start, end, limit).await + } + + pub async fn get(&self, key: impl Into>) -> Result> { + self.inner.get(key).await + } + + pub async fn put( + &self, + key: impl Into>, + value: Option, + expect_ts: Option, + ) -> Result { + self.inner.put(key, value, expect_ts).await + } + + pub async fn increment(&self, key: impl Into>, increment: i64) -> Result { + self.inner.increment(key, increment).await + } + + pub async fn scan( + &self, + start: impl Into>, + end: impl Into>, + limit: u32, + ) -> Result<(Vec, Vec)> { + self.inner.scan(start, end, limit).await + } + + pub async fn locate(&self, key: impl Into>) -> Result { + self.inner.locate(key).await + } + + pub async fn participate_txn( + &self, + request: ParticipateTxnRequest, + coordinator: bool, + ) -> Result<(mpsc::Sender, tonic::Streaming)> { + self.inner.participate_txn(request, coordinator).await + } + + pub async fn open_participate_txn( + &self, + request: ParticipateTxnRequest, + coordinator: bool, + ) -> (mpsc::Sender, tonic::Streaming) { + self.inner.open_participate_txn(request, coordinator).await + } + + pub async fn service( + &self, + key: &[u8], + ) -> Result<(ShardDeployment, TabletServiceClient)> { + self.inner.service(key).await + } + + pub async fn batch(&self, requests: Vec) -> Result> { + self.inner.batch(requests).await + } + + pub async fn find(&self, key: &[u8]) -> Result, Value)>> { + self.inner.find(key).await + } + + pub async fn get_tablet_descriptor(&self, id: TabletId) -> Result<(Timestamp, TabletDescriptor)> { + self.inner.get_tablet_descriptor(id).await + } +} + #[cfg(test)] mod tests { use std::pin::pin; @@ -675,29 +1248,26 @@ mod tests { tokio::time::sleep(Duration::from_secs(20)).await; - let client = TabletClient::new(cluster_env); + let client = TabletClient::new(cluster_env).scope(keys::USER_KEY_PREFIX); - let count = client.increment(keys::user_key(b"count"), 5).await.unwrap(); + let count = client.increment(b"count", 5).await.unwrap(); assert_eq!(count, 5); - let count = client.increment(keys::user_key(b"count"), 5).await.unwrap(); + let count = client.increment(b"count", 5).await.unwrap(); assert_eq!(count, 10); - let put_ts = client.put(keys::user_key(b"k1"), Value::Bytes(b"v1_1".to_vec()), None).await.unwrap(); - let (get_ts, value) = client.get(keys::user_key(b"k1")).await.unwrap().unwrap(); + let put_ts = client.put(b"k1", Some(Value::Bytes(b"v1_1".to_vec())), None).await.unwrap(); + let (get_ts, value) = client.get(b"k1").await.unwrap().unwrap(); assert_that!(get_ts).is_equal_to(put_ts); assert_that!(value.into_bytes().unwrap()).is_equal_to(b"v1_1".to_vec()); - let put_ts = client.put(keys::user_key(b"k1"), Value::Bytes(b"v1_2".to_vec()), Some(put_ts)).await.unwrap(); - let (get_ts, value) = client.get(keys::user_key(b"k1")).await.unwrap().unwrap(); + let put_ts = client.put(b"k1", Some(Value::Bytes(b"v1_2".to_vec())), Some(put_ts)).await.unwrap(); + let (get_ts, value) = client.get(b"k1").await.unwrap().unwrap(); assert_that!(get_ts).is_equal_to(put_ts); assert_that!(value.into_bytes().unwrap()).is_equal_to(b"v1_2".to_vec()); - client.delete(keys::user_key(b"k1"), 
Some(put_ts)).await.unwrap(); - assert_that!(client.get(keys::user_key(b"k1")).await.unwrap().is_none()).is_true(); - let put_ts = client - .put(keys::user_key(b"k1"), Value::Bytes(b"v1_3".to_vec()), Some(Timestamp::default())) - .await - .unwrap(); + client.put(b"k1", None, Some(put_ts)).await.unwrap(); + assert_that!(client.get(b"k1").await.unwrap().is_none()).is_true(); + let put_ts = client.put(b"k1", Some(Value::Bytes(b"v1_3".to_vec())), Some(Timestamp::default())).await.unwrap(); let (ts, key, value) = client.find(b"k").await.unwrap().unwrap(); assert_that!(ts).is_equal_to(put_ts); @@ -733,15 +1303,15 @@ mod tests { let deployment_watcher = cluster_meta_handle.watch_deployment(None).await.unwrap(); let cluster_env = cluster_env.with_descriptor(descriptor_watcher).with_deployment(deployment_watcher.monitor()); let _node = TabletNode::start(node_id, listener, lease, cluster_env.clone()); - let client = TabletClient::new(cluster_env); + let client = TabletClient::new(cluster_env).scope(keys::USER_KEY_PREFIX); tokio::time::sleep(Duration::from_secs(20)).await; let requests = vec![ - DataRequest::Increment(IncrementRequest { key: keys::user_key(b"count"), increment: 5, sequence: 0 }), - DataRequest::Get(GetRequest { key: keys::user_key(b"count"), sequence: 0 }), - DataRequest::Increment(IncrementRequest { key: keys::user_key(b"count"), increment: 5, sequence: 0 }), + DataRequest::Increment(IncrementRequest { key: b"count".into(), increment: 5, sequence: 0 }), + DataRequest::Get(GetRequest { key: b"count".into(), sequence: 0 }), + DataRequest::Increment(IncrementRequest { key: b"count".into(), increment: 5, sequence: 0 }), DataRequest::Put(PutRequest { - key: keys::user_key(b"k1"), + key: b"k1".into(), value: Some(Value::String("v1_1".to_owned())), expect_ts: None, sequence: 0, @@ -755,9 +1325,9 @@ mod tests { assert_that!(responses.pop().unwrap().into_increment().unwrap().value).is_equal_to(5); let requests = vec![ - DataRequest::Increment(IncrementRequest { key: keys::user_key(b"count"), increment: 5, sequence: 0 }), + DataRequest::Increment(IncrementRequest { key: b"count".into(), increment: 5, sequence: 0 }), DataRequest::Put(PutRequest { - key: keys::user_key(b"k1"), + key: b"k1".into(), value: Some(Value::String("v1_1".to_owned())), expect_ts: Some(Timestamp::ZERO), sequence: 0, @@ -766,9 +1336,9 @@ mod tests { client.batch(requests).await.unwrap_err(); let requests = vec![ - DataRequest::Increment(IncrementRequest { key: keys::user_key(b"count"), increment: 5, sequence: 0 }), + DataRequest::Increment(IncrementRequest { key: b"count".into(), increment: 5, sequence: 0 }), DataRequest::Put(PutRequest { - key: keys::user_key(b"k1"), + key: b"k1".into(), value: Some(Value::String("v1_2".to_owned())), expect_ts: Some(put_ts), sequence: 0, @@ -779,15 +1349,15 @@ mod tests { assert_that!(responses.pop().unwrap().into_increment().unwrap().value).is_equal_to(15); let requests = vec![ - DataRequest::Get(GetRequest { key: keys::user_key(b"k1"), sequence: 0 }), - DataRequest::Find(FindRequest { key: keys::user_key(b"k1"), sequence: 0 }), + DataRequest::Get(GetRequest { key: b"k1".into(), sequence: 0 }), + DataRequest::Find(FindRequest { key: b"k1".into(), sequence: 0 }), ]; let mut responses = client.batch(requests).await.unwrap(); let expect_value = TimestampedValue { value: Value::String("v1_2".to_owned()), timestamp: put_ts }; let find = responses.pop().unwrap().into_find().unwrap(); - assert_that!(find.key).is_equal_to(keys::user_key(b"k1")); + 
assert_that!(find.key).is_equal_to(b"k1".to_vec()); assert_that!(find.value.unwrap()).is_equal_to(&expect_value); let get = responses.pop().unwrap().into_get().unwrap(); @@ -815,15 +1385,15 @@ mod tests { let deployment_watcher = cluster_meta_handle.watch_deployment(None).await.unwrap(); let cluster_env = cluster_env.with_descriptor(descriptor_watcher).with_deployment(deployment_watcher.monitor()); let _node = TabletNode::start(node_id, listener, lease, cluster_env.clone()); - let client = TabletClient::new(cluster_env); + let client = TabletClient::new(cluster_env).scope(keys::USER_KEY_PREFIX); tokio::time::sleep(Duration::from_secs(20)).await; let requests = vec![ - DataRequest::Find(FindRequest { key: keys::user_key(b"count"), sequence: 0 }), - DataRequest::Increment(IncrementRequest { key: keys::user_key(b"count0"), increment: 5, sequence: 0 }), - DataRequest::Increment(IncrementRequest { key: keys::user_key(b"count1"), increment: 10, sequence: 0 }), - DataRequest::Find(FindRequest { key: keys::user_key(b"count"), sequence: 0 }), - DataRequest::Find(FindRequest { key: keys::user_key(b"count01"), sequence: 0 }), + DataRequest::Find(FindRequest { key: b"count".into(), sequence: 0 }), + DataRequest::Increment(IncrementRequest { key: b"count0".into(), increment: 5, sequence: 0 }), + DataRequest::Increment(IncrementRequest { key: b"count1".into(), increment: 10, sequence: 0 }), + DataRequest::Find(FindRequest { key: b"count".into(), sequence: 0 }), + DataRequest::Find(FindRequest { key: b"count01".into(), sequence: 0 }), ]; let mut responses = client.batch(requests).await.unwrap(); let find = responses.remove(0).into_find().unwrap(); @@ -834,12 +1404,12 @@ mod tests { responses.remove(0); let find0 = responses.remove(0).into_find().unwrap(); - assert_that!(find0.key).is_equal_to(keys::user_key(b"count0")); + assert_that!(find0.key).is_equal_to(b"count0".to_vec()); assert_that!(find0.value).is_some(); assert_that!(find0.value.unwrap().value).is_equal_to(Value::Int(5)); let find1 = responses.remove(0).into_find().unwrap(); - assert_that!(find1.key).is_equal_to(keys::user_key(b"count1")); + assert_that!(find1.key).is_equal_to(b"count1".to_vec()); assert_that!(find1.value.unwrap().value).is_equal_to(Value::Int(10)); } @@ -864,20 +1434,20 @@ mod tests { let deployment_watcher = cluster_meta_handle.watch_deployment(None).await.unwrap(); let cluster_env = cluster_env.with_descriptor(descriptor_watcher).with_deployment(deployment_watcher.monitor()); let _node = TabletNode::start(node_id, listener, lease, cluster_env.clone()); - let client = TabletClient::new(cluster_env); + let client = TabletClient::new(cluster_env).scope(keys::USER_KEY_PREFIX); tokio::time::sleep(Duration::from_secs(20)).await; let requests = vec![ DataRequest::Scan(ScanRequest { - range: KeyRange { start: keys::user_key(b"count"), end: keys::user_key(b"counu") }, + range: KeyRange { start: b"count".into(), end: b"counu".into() }, limit: 0, sequence: 0, }), - DataRequest::Increment(IncrementRequest { key: keys::user_key(b"count0"), increment: 5, sequence: 0 }), - DataRequest::Increment(IncrementRequest { key: keys::user_key(b"count1"), increment: 10, sequence: 0 }), - DataRequest::Increment(IncrementRequest { key: keys::user_key(b"count2"), increment: 15, sequence: 0 }), + DataRequest::Increment(IncrementRequest { key: b"count0".into(), increment: 5, sequence: 0 }), + DataRequest::Increment(IncrementRequest { key: b"count1".into(), increment: 10, sequence: 0 }), + DataRequest::Increment(IncrementRequest { key: b"count2".into(), 
increment: 15, sequence: 0 }), DataRequest::Scan(ScanRequest { - range: KeyRange { start: keys::user_key(b"count"), end: keys::user_key(b"counu") }, + range: KeyRange { start: b"count".into(), end: b"counu".into() }, limit: 2, sequence: 0, }), @@ -894,9 +1464,9 @@ mod tests { let scan = responses.remove(0).into_scan().unwrap(); debug!("resume key: {:?}", scan.resume_key); assert_that!(scan.rows).has_length(2); - assert_that!(scan.rows[0].key).is_equal_to(keys::user_key(b"count0")); + assert_that!(scan.rows[0].key).is_equal_to(b"count0".to_vec()); assert_that!(scan.rows[0].value).is_equal_to(Value::Int(5)); - assert_that!(scan.rows[1].key).is_equal_to(keys::user_key(b"count1")); + assert_that!(scan.rows[1].key).is_equal_to(b"count1".to_vec()); assert_that!(scan.rows[1].value).is_equal_to(Value::Int(10)); } @@ -1702,7 +2272,7 @@ mod tests { .await .unwrap_err(); - assert_that!(status.message()).contains("fail to refresh key"); + assert_that!(status.message()).contains("conflict with write to key"); } #[test_log::test(tokio::test)] diff --git a/src/tablet/concurrency.rs b/src/tablet/concurrency.rs index f4ff0ab..c3c69fc 100644 --- a/src/tablet/concurrency.rs +++ b/src/tablet/concurrency.rs @@ -35,6 +35,7 @@ use self::fence::FenceTable; use crate::clock::Clock; use crate::keys::Key; use crate::protos::{ + BatchError, BatchRequest, BatchResponse, HasTxnMeta, @@ -63,11 +64,11 @@ pub struct Request { pub temporal: Temporal, pub requests: Vec, - pub responser: oneshot::Sender>, + pub responser: oneshot::Sender>, } impl Request { - pub fn new(request: BatchRequest, responser: oneshot::Sender>) -> Self { + pub fn new(request: BatchRequest, responser: oneshot::Sender>) -> Self { let mut read_keys = vec![]; let mut write_keys = vec![]; let BatchRequest { temporal, requests, .. } = request; @@ -1020,7 +1021,7 @@ impl TxnTable { let participate = match self.txns.participate(txn) { Ok(participate) => participate, Err(err) => { - request.responser.send(Err(err)).ignore(); + request.responser.send(Err(BatchError::with_message(err.to_string()))).ignore(); return None; }, }; diff --git a/src/tablet/concurrency/fence.rs b/src/tablet/concurrency/fence.rs index 495af3c..888fc02 100644 --- a/src/tablet/concurrency/fence.rs +++ b/src/tablet/concurrency/fence.rs @@ -56,7 +56,7 @@ impl Barrier { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct Fence { span: KeySpan, barrier: Barrier, @@ -70,14 +70,14 @@ impl Fence { pub fn split(&mut self, pivot: Vec) -> Self { let mut end = pivot; std::mem::swap(&mut self.span.end, &mut end); - let next_span = KeySpan { key: self.span.end.clone(), end }; + let next_span = KeySpan { key: self.span.end().into_owned(), end }; Self { span: next_span, barrier: self.barrier } } } /// Fence table tracks maximum span read timestamp to calculate minimum span write timestamp to /// prevent write-beneath-read. 
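An aside before the fence changes below: the contract in that doc comment is easiest to see through the module's own API. A minimal sketch distilled from the assertions in the unit tests later in this hunk (same in-module test setup assumed; not part of the diff):

```rust
use std::time::Duration;

// FenceTable lives in src/tablet/concurrency/fence.rs; like the tests there,
// this sketch assumes module-level access to it.
use crate::protos::{KeySpan, Timestamp, Uuid};

fn fence_contract(mut fences: FenceTable) {
    let ts10 = Timestamp::ZERO + Duration::from_secs(10);
    let ts50 = Timestamp::ZERO + Duration::from_secs(50);

    // Nothing read or closed yet: writes only need to clear Timestamp::ZERO.
    assert_eq!(fences.min_write_ts(Uuid::nil(), &[KeySpan::new_key(b"k1")], Timestamp::ZERO), Timestamp::ZERO.next());
    assert_eq!(fences.min_write_ts(Uuid::nil(), &[KeySpan::new_key(b"k1")], ts10), ts10);

    // Record a read at ts10, then close ts50: later writes from any
    // transaction are pushed above the closed timestamp.
    fences.fence(Uuid::nil(), KeySpan::new_range("k1", "k2"), ts10);
    fences.close_ts(ts50);
    assert_eq!(fences.min_write_ts(Uuid::new_random(), &[KeySpan::new_key(b"k1")], ts10), ts50.next());
}
```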
-#[derive(Default, Debug)] +#[derive(Default, Clone, Debug)] pub struct FenceTable { fences: Vec, closed_ts: Timestamp, @@ -168,7 +168,8 @@ impl FenceTable { }, SpanOrdering::ContainLeft => { fence.barrier.update(txn_id, ts); - span.key.clone_from(&fence.span.end); + span.key.clear(); + fence.span.append_end(&mut span.key); i += 1; }, SpanOrdering::ContainAll => { @@ -258,7 +259,7 @@ mod tests { use assertor::*; - use super::FenceTable; + use super::{Fence, FenceTable}; use crate::protos::{KeySpan, Timestamp, Uuid}; #[test] @@ -274,23 +275,23 @@ mod tests { let txn3 = Uuid::new_random(); let mut fences = FenceTable::default(); - assert_that!(fences.min_write_ts(Uuid::nil(), &vec![KeySpan::new_key(b"k1")], Timestamp::ZERO)) + assert_that!(fences.min_write_ts(Uuid::nil(), &[KeySpan::new_key(b"k1")], Timestamp::ZERO)) .is_equal_to(Timestamp::ZERO.next()); - assert_that!(fences.min_write_ts(Uuid::nil(), &vec![KeySpan::new_key(b"k1")], ts10)).is_equal_to(ts10); - assert_that!(fences.min_write_ts(Uuid::new_random(), &vec![KeySpan::new_key(b"k1")], Timestamp::ZERO)) + assert_that!(fences.min_write_ts(Uuid::nil(), &[KeySpan::new_key(b"k1")], ts10)).is_equal_to(ts10); + assert_that!(fences.min_write_ts(Uuid::new_random(), &[KeySpan::new_key(b"k1")], Timestamp::ZERO)) .is_equal_to(Timestamp::ZERO.next()); - assert_that!(fences.min_write_ts(Uuid::new_random(), &vec![KeySpan::new_key(b"k1")], ts10)).is_equal_to(ts10); + assert_that!(fences.min_write_ts(Uuid::new_random(), &[KeySpan::new_key(b"k1")], ts10)).is_equal_to(ts10); // given: close ts50 fences.close_ts(ts50); // then: min write ts advances to ts50.next() - assert_that!(fences.min_write_ts(Uuid::nil(), &vec![KeySpan::new_key(b"k1")], Timestamp::ZERO)) + assert_that!(fences.min_write_ts(Uuid::nil(), &[KeySpan::new_key(b"k1")], Timestamp::ZERO)) .is_equal_to(ts50.next()); - assert_that!(fences.min_write_ts(Uuid::nil(), &vec![KeySpan::new_key(b"k1")], ts10)).is_equal_to(ts50.next()); - assert_that!(fences.min_write_ts(Uuid::new_random(), &vec![KeySpan::new_key(b"k1")], Timestamp::ZERO)) + assert_that!(fences.min_write_ts(Uuid::nil(), &[KeySpan::new_key(b"k1")], ts10)).is_equal_to(ts50.next()); + assert_that!(fences.min_write_ts(Uuid::new_random(), &[KeySpan::new_key(b"k1")], Timestamp::ZERO)) .is_equal_to(ts50.next()); - assert_that!(fences.min_write_ts(Uuid::new_random(), &vec![KeySpan::new_key(b"k1")], ts10)) + assert_that!(fences.min_write_ts(Uuid::new_random(), &[KeySpan::new_key(b"k1")], ts10)) .is_equal_to(ts50.next()); // given: @@ -388,4 +389,41 @@ mod tests { fences.fence(Uuid::nil(), KeySpan::new_range("k19", "k20"), ts70); } + + #[test] + fn contain_all_span_key() { + let mut fences = FenceTable::default(); + let txn_id = Uuid::new_random(); + let ts = Timestamp::ZERO + Duration::from_secs(10); + + fences.fences.push(Fence::new(KeySpan { key: b"abc".to_vec(), end: vec![] }, txn_id, ts)); + fences.fences.push(Fence::new(KeySpan { key: b"abd".to_vec(), end: b"abe".to_vec() }, txn_id, ts)); + fences.fences.push(Fence::new(KeySpan { key: b"abef".to_vec(), end: vec![] }, txn_id, ts)); + + fences.fence(txn_id, KeySpan { key: b"ab".to_vec(), end: b"af".to_vec() }, ts); + } + + #[test] + fn subset_all_span_key() { + let mut fences = FenceTable::default(); + let txn_id = Uuid::new_random(); + let ts = Timestamp::ZERO + Duration::from_secs(10); + + fences.fences.push(Fence::new(KeySpan { key: b"abc".to_vec(), end: b"abf".to_vec() }, txn_id, ts)); + fences.fence(txn_id, KeySpan { key: b"abd".to_vec(), end: vec![] }, ts + Duration::from_secs(1)); 
+ } + + #[test] + fn contain_left_span_key() { + use std::str::FromStr; + let mut fences = FenceTable::default(); + let txn_id = uuid::Uuid::from_str("5f8c2d8b-965c-455a-8e97-4392c59b8839").unwrap().into(); + let ts = Timestamp { seconds: 1733404865, nanoseconds: 50223000, logical: 0 }; + + fences.fences.push(Fence::new(KeySpan { key: b"abc".to_vec(), end: vec![] }, txn_id, ts)); + fences.fences.push(Fence::new(KeySpan { key: b"abd".to_vec(), end: b"abf".to_vec() }, txn_id, ts)); + fences.fences.push(Fence::new(KeySpan { key: b"abfg".to_vec(), end: vec![] }, txn_id, ts)); + + fences.fence(txn_id, KeySpan { key: b"abc".to_vec(), end: b"abcd".to_vec() }, ts); + } } diff --git a/src/tablet/deployer.rs b/src/tablet/deployer.rs index d94ba0b..67717a8 100644 --- a/src/tablet/deployer.rs +++ b/src/tablet/deployer.rs @@ -130,10 +130,10 @@ pub trait TabletDeployer { let nodes = self.nodes().clone(); let deployment_receiver = deployment_watcher.clone(); let crash_reporter = crash_reporter.clone(); - let deployment_span = span!(Level::INFO, "cluster deployment deployment", %node, %addr); + let deployment_span = span!(Level::INFO, "cluster deployment", %node, %addr); tokio::spawn(async move { if let Err(err) = nodes.start_deployment(&node, addr, deployment_receiver).await { - tracing::info!("deployment deployment terminated: {}", err); + tracing::info!("deployment terminated: {}", err); } crash_reporter.send(node).ignore(); }.instrument(deployment_span)); @@ -384,7 +384,6 @@ impl RangeTabletDeployer { receiver } - #[instrument(skip(self), fields(self.shard = %self.shard_id, self.tablet = %self.tablet_id))] async fn serve_internally(&self) -> Result<()> { let mut deployments = self.load_deployments(); let mut load_completed = false; @@ -402,7 +401,7 @@ impl RangeTabletDeployer { Ok(()) } - #[instrument(skip(self), fields(self.shard = %self.shard_id, self.tablet = %self.tablet_id))] + #[instrument(skip(self), fields(shard = %self.shard_id, tablet = %self.tablet_id))] async fn serve(&self) { if let Err(err) = self.serve_internally().await { debug!("fail to serve range tablet deployer: {err}"); diff --git a/src/tablet/loader.rs b/src/tablet/loader.rs index d84f4e2..79d6176 100644 --- a/src/tablet/loader.rs +++ b/src/tablet/loader.rs @@ -17,14 +17,16 @@ use std::sync::Arc; use anyhow::{anyhow, bail, Result}; use async_trait::async_trait; -use derivative::Derivative; +use derive_where::derive_where; use tokio::select; +use tracing::trace; use super::store::{BatchContext, BatchResult, TabletStore, TxnTabletStore}; use super::types::{MessageId, TabletWatermark}; use crate::clock::{Clock, Timestamp}; use crate::log::{ByteLogProducer, ByteLogSubscriber, LogAddress, LogManager, LogOffset, LogPosition}; use crate::protos::{ + BatchError, BatchRequest, BatchResponse, DataMessage, @@ -85,11 +87,11 @@ pub trait LogMessageProducer: ByteLogProducer { } } -#[derive(Derivative)] -#[derivative(Debug)] +#[allow(clippy::needless_lifetimes)] +#[derive_where(Debug)] pub struct TypedLogConsumer<'a, T> { consumer: &'a mut dyn ByteLogSubscriber, - #[derivative(Debug = "ignore")] + #[derive_where(skip(Debug))] _marker: std::marker::PhantomData, } @@ -101,7 +103,7 @@ impl<'a, T> TypedLogConsumer<'a, T> { #[async_trait] impl ByteLogSubscriber for TypedLogConsumer<'_, T> { - async fn read(&mut self) -> Result<(LogPosition, &[u8])> { + async fn read<'a>(&'a mut self) -> Result<(LogPosition, &'a [u8])> { self.consumer.read().await } @@ -148,20 +150,19 @@ impl<'a, T> LimitedLogConsumer<'a, T> { } } -#[derive(Derivative)] 
-#[derivative(Debug)] +#[derive_where(Debug)] pub struct MessageConsumer { last_recv: LogPosition, last_epoch: u64, last_sequence: u64, consumer: Box, - #[derivative(Debug = "ignore")] + #[derive_where(skip(Debug))] _marker: std::marker::PhantomData, } #[async_trait] impl ByteLogSubscriber for MessageConsumer { - async fn read(&mut self) -> Result<(LogPosition, &[u8])> { + async fn read<'a>(&'a mut self) -> Result<(LogPosition, &'a [u8])> { let (offset, _, payload) = (self as &mut MessageConsumer).read().await?; Ok((offset, payload)) } @@ -183,12 +184,11 @@ impl LogMessageConsumer for MessageConsumer { } } -#[derive(Derivative)] -#[derivative(Debug)] +#[derive_where(Debug)] pub struct BufMessageProducer { buf: Vec, producer: Box, - #[derivative(Debug = "ignore")] + #[derive_where(skip(Debug))] _marker: std::marker::PhantomData, } @@ -309,7 +309,7 @@ impl MessageProducer { if position.is_next_of(&self.last_sent) { return Ok(position); } - tracing::debug!("interleaving sending detected: last sent position {}, but got next {}", self.last_sent, position); + trace!("interleaving sending detected: last sent position {}, but got next {}", self.last_sent, position); }, } } @@ -465,7 +465,7 @@ impl LeadingTablet { let watermark = &self.manifest.manifest.watermark; if ts > watermark.leader_expiration { return Ok(Some(BatchResult::Error { - error: anyhow!("write above leader expiration"), + error: BatchError::with_message("write above leader expiration"), responser: request.responser, })); } @@ -516,12 +516,16 @@ impl FollowingTablet { } } - fn query(&mut self, ts: Timestamp, requests: Vec) -> Result> { + fn query(&mut self, ts: Timestamp, requests: Vec) -> Result, BatchError> { let mut context = BatchContext::default(); self.store.store.batch_timestamped(&mut context, ts, requests) } - pub fn query_batch(&mut self, deployment: &TabletDeployment, batch: BatchRequest) -> Result { + pub fn query_batch( + &mut self, + deployment: &TabletDeployment, + batch: BatchRequest, + ) -> Result { let mut response = BatchResponse::default(); if let Some(ts) = self.extract_read_timestamp(&batch) { response.responses = self.query(ts, batch.requests)?; @@ -577,10 +581,7 @@ impl TabletLoader { (_, None) => continue, }, Ok(None) => bail!("no manifest message"), - Err(err) => { - tracing::warn!("XXXX: {:?}", err); - return Err(err); - }, + Err(err) => return Err(err), } }; while let Some(message) = subscriber.read().await? { @@ -696,6 +697,7 @@ impl TabletLoader { return Err(anyhow!("do not support file compaction and transaction rotation for now")); } let uri = tablet.data_log.as_str().try_into()?; + trace!("subscribing log {uri}"); let mut consumer: Box = self.log.subscribe_log(&uri, LogOffset::Earliest).await?; let limit = match limit { None => consumer.latest().await?, diff --git a/src/tablet/server.rs b/src/tablet/server.rs index 6187106..b9f23b7 100644 --- a/src/tablet/server.rs +++ b/src/tablet/server.rs @@ -73,6 +73,17 @@ impl TabletServiceImpl { ) -> Result { self.request(request, receiver).await?.map_err(|e| Status::invalid_argument(e.to_string())) } + + async fn request_message_result( + &self, + request: TabletServiceRequest, + receiver: oneshot::Receiver>, + ) -> Result { + match self.request(request, receiver).await? 
{ + Ok(response) => Ok(response), + Err(err) => Err(Status::with_details(tonic::Code::Internal, err.to_string(), err.encode_to_vec().into())), + } + } } #[async_trait] @@ -104,7 +115,10 @@ impl TabletService for TabletServiceImpl { async fn batch(&self, request: Request) -> Result, Status> { let (sender, receiver) = oneshot::channel(); let response = self - .request_result(TabletServiceRequest::Batch { batch: request.into_inner(), responser: sender }, receiver) + .request_message_result( + TabletServiceRequest::Batch { batch: request.into_inner(), responser: sender }, + receiver, + ) .await?; Ok(Response::new(response)) } diff --git a/src/tablet/service.rs b/src/tablet/service.rs index 478f47a..e4984a1 100644 --- a/src/tablet/service.rs +++ b/src/tablet/service.rs @@ -30,6 +30,7 @@ use crate::clock::Clock; use crate::cluster::{ClusterEnv, NodeId}; use crate::protos::{ self, + BatchError, BatchRequest, BatchResponse, DataMessage, @@ -62,7 +63,7 @@ use crate::utils::{self, DropOwner}; struct BatchResponser { temporal: Temporal, responses: Vec, - responser: oneshot::Sender>, + responser: oneshot::Sender>, } impl BatchResponser { @@ -245,7 +246,7 @@ impl TabletServiceState { Ok(BatchResponse { temporal, responses, deployments: Default::default() }) } else { assert!(stage == ReplicationStage::Failed); - Err(anyhow!("replication failed")) + Err(BatchError::with_message("replication failed".to_string())) }; responser.send(result).ignore(); }) @@ -304,6 +305,7 @@ impl TabletServiceState { } }, Some(mut txn) = tablet.store.store.updated_txns().recv() => { + trace!("update txn {:?}", txn); let (replication, requests) = tablet.store.store.update_txn(&mut txn); trace!("unblock txn {}(epoch:{}, {:?}) requests {:?}", txn.id(), txn.epoch(), txn.status(), requests); unblocking_requests.extend(requests.into_iter()); @@ -496,14 +498,14 @@ impl TabletServiceManager { responser.send(Ok(())).ignore(); } - fn apply_batch(&mut self, batch: BatchRequest, responser: oneshot::Sender>) { + fn apply_batch(&mut self, batch: BatchRequest, responser: oneshot::Sender>) { if let Some((_, requester)) = self.tablets.get(&batch.tablet_id) { let request = TabletRequest::Batch { batch, responser }; if let Err(mpsc::error::SendError(TabletRequest::Batch { batch, responser })) = requester.send(request) { - responser.send(Err(anyhow!("tablet {} closed", batch.tablet_id))).ignore(); + responser.send(Err(BatchError::with_message(format!("tablet {} closed", batch.tablet_id)))).ignore(); } } else { - responser.send(Err(anyhow!("tablet {} not found", batch.tablet_id))).ignore(); + responser.send(Err(BatchError::with_message(format!("tablet {} not found", batch.tablet_id)))).ignore(); } } diff --git a/src/tablet/store.rs b/src/tablet/store.rs index 79743a2..48da410 100644 --- a/src/tablet/store.rs +++ b/src/tablet/store.rs @@ -17,7 +17,7 @@ use std::collections::btree_map::{BTreeMap, Entry as BTreeEntry}; use std::collections::VecDeque; use std::ops::Bound; -use anyhow::{anyhow, bail, Result}; +use anyhow::Result; use hashbrown::hash_map::{Entry as HashEntry, HashMap}; use ignore_result::Ignore; use tokio::sync::{mpsc, oneshot}; @@ -38,7 +38,9 @@ use crate::clock::Clock; use crate::keys::Key; use crate::protos::{ self, + BatchError, BatchResponse, + DataError, DataOperation, DataRequest, DataResponse, @@ -122,19 +124,34 @@ impl Default for Writes { } } +impl Writes { + pub fn is_empty(&self) -> bool { + match self { + Self::Write(_) => false, + Self::Batch(batch) => batch.is_empty(), + } + } +} + +impl From for DataError { + fn 
from(err: anyhow::Error) -> Self { + DataError::internal(err.to_string()) + } +} + #[derive(Debug)] pub enum BatchResult { Read { temporal: Temporal, responses: Vec, - responser: oneshot::Sender>, + responser: oneshot::Sender>, blocker: ReplicationWatcher, }, Write { temporal: Temporal, responses: Vec, - responser: oneshot::Sender>, + responser: oneshot::Sender>, writes: Vec, replication: ReplicationTracker, @@ -142,8 +159,8 @@ pub enum BatchResult { requests: Vec, }, Error { - error: anyhow::Error, - responser: oneshot::Sender>, + error: BatchError, + responser: oneshot::Sender>, }, } @@ -177,8 +194,13 @@ impl ShardStore { self.store.get(key, ts) } - pub fn get_timestamped(&self, context: &BatchContext, ts: Timestamp, key: &[u8]) -> Result { - self.get(context, ts, key).map(|value| value.into_client()) + pub fn get_timestamped( + &self, + context: &BatchContext, + ts: Timestamp, + key: &[u8], + ) -> Result { + self.get(context, ts, key).map(|value| value.into_client()).map_err(|e| DataError::internal(e.to_string())) } pub fn find(&self, context: &BatchContext, ts: Timestamp, key: &[u8]) -> Result<(Vec, TimestampedValue)> { @@ -210,24 +232,16 @@ impl ShardStore { limit: u32, ) -> Result<(Vec, Vec)> { let mut resume_key = range.start; - let end_key = if range.end <= self.range.end { - range.end - } else { - let mut end = range.end; - end.clone_from(&self.range.end); - end - }; - trace!("end key: {end_key:?}"); let mut rows = vec![]; while limit == 0 || rows.len() < limit as usize { trace!("find key: {resume_key:?}"); let (key, value) = self.find_timestamped(context, ts, &resume_key)?; trace!("found: {key:?}, {value:?}"); if key.is_empty() { - resume_key.clone_from(&self.range.end); + resume_key = self.range.resume_from(range.end); break; - } else if key >= end_key { - resume_key = end_key; + } else if key >= range.end { + resume_key = Vec::default(); break; } resume_key.clone_from(&key); @@ -238,11 +252,16 @@ impl ShardStore { Ok((resume_key, rows)) } - fn check_timestamped_write(&self, context: &BatchContext, ts: Timestamp, key: &[u8]) -> Result { + fn check_timestamped_write( + &self, + context: &BatchContext, + ts: Timestamp, + key: &[u8], + ) -> Result { let value = self.get(context, ts, key)?; // FIXME: allow equal timestamp for now to gain write-your-write. 
if ts < value.timestamp { - return Err(anyhow!("write@{} encounters newer timestamp {}", ts, value.timestamp)); + return Err(DataError::conflict_write(key.to_vec(), value.timestamp)); } Ok(value.into_client()) } @@ -268,10 +287,10 @@ impl ShardStore { key: &[u8], value: Option, expect_ts: Option, - ) -> Result { + ) -> Result { let existing_value = self.check_timestamped_write(context, ts, key)?; - if let Some(expect_ts) = expect_ts.filter(|ts| *ts != existing_value.timestamp) { - bail!("mismatch timestamp check: existing ts {}, expect ts {:}", existing_value.timestamp, expect_ts) + if expect_ts.filter(|ts| *ts != existing_value.timestamp).is_some() { + return Err(DataError::timestamp_mismatch(key, existing_value.timestamp)); } self.add_timestamped_write(context, ts, key.to_owned(), value); Ok(ts) @@ -283,7 +302,7 @@ impl ShardStore { ts: Timestamp, key: &[u8], increment: i64, - ) -> Result { + ) -> Result { let value = self.check_timestamped_write(context, ts, key)?; let i = value.value.read_int(key, "increment")?; let incremented = i + increment; @@ -298,41 +317,30 @@ impl ShardStore { ts: Timestamp, span: KeySpan, from: Timestamp, - ) -> Result<()> { + ) -> Result, DataError> { if span.end.is_empty() { let found = self.get(context, ts, &span.key)?; if found.timestamp <= from { - return Ok(()); + return Ok(Vec::default()); } - bail!( - "fail to refresh key {:?} from {} to {} due to write at timestamp {:?}", - span.key, - from, - ts, - found.timestamp - ); + return Err(DataError::conflict_write(span.key, found.timestamp)); } let mut start = span.key; - let end = span.end; loop { let (key, value) = self.find(context, ts, &start)?; - if key.is_empty() || key >= end { + if key.is_empty() { + return Ok(self.range.resume_from(span.end)); + } else if key >= span.end { break; } if value.timestamp > from { - bail!( - "fail to refresh key {:?} from {:?} to {:?} due to write at timestamp {:?}", - key, - from, - ts, - value.timestamp - ); + return Err(DataError::conflict_write(key, value.timestamp)); } start.clear(); start.extend(&key); start.push(0); } - Ok(()) + Ok(Vec::default()) } fn add_transactional_write( @@ -373,7 +381,7 @@ impl ShardStore { txn: &TxnRecord, key: &[u8], sequence: u32, - ) -> Result { + ) -> Result { if let Some(intent) = provision.get_intent(context, key) { assert!(intent.txn.is_same(txn)); if let Some(value) = intent.get_value(sequence) { @@ -391,7 +399,7 @@ impl ShardStore { txn: &TxnRecord, key: &[u8], sequence: u32, - ) -> Result<(Vec, TimestampedValue)> { + ) -> Result<(Vec, TimestampedValue), DataError> { let (found, value) = self.find_timestamped(context, txn.commit_ts(), key)?; let mut start = key.to_owned(); let end = found.as_slice(); @@ -416,21 +424,14 @@ impl ShardStore { sequence: u32, ) -> Result<(Vec, Vec)> { let mut resume_key = range.start; - let end_key = if range.end <= self.range.end { - range.end - } else { - let mut end = range.end; - end.clone_from(&self.range.end); - end - }; let mut rows = vec![]; while limit == 0 || rows.len() < limit as usize { let (key, value) = self.find_transactional(context, provision, txn, &resume_key, sequence)?; if key.is_empty() { - resume_key.clone_from(&self.range.end); + resume_key = self.range.resume_from(range.end); break; - } else if key >= end_key { - resume_key = end_key; + } else if key >= range.end { + resume_key = Vec::default(); break; } resume_key.clone_from(&key); @@ -461,10 +462,13 @@ impl ShardStore { value: ValueExtractor, sequence: u32, existing_ts: Option, - ) -> Result { + ) -> Result { if sequence <= 
intent.value.sequence { let Some(intent_value) = intent.get_value(sequence) else { - bail!("txn {:?} key {:?}: can't find written value for old sequence {}", intent.txn, key, sequence); + return Err(DataError::internal(format!( + "txn {:?} key {:?}: can't find written value for old sequence {}", + intent.txn, key, sequence + ))); }; let writting = match value { ValueExtractor::Plain(value) => value, @@ -475,14 +479,10 @@ impl ShardStore { }; let written = intent_value.value.value; if writting != written { - bail!( + return Err(DataError::internal(format!( "txn {:?} key {:?}: non idempotent write at sequence {}, old write {:?}, new write {:?}", - intent.txn, - key, - sequence, - written, - writting - ); + intent.txn, key, sequence, written, writting + ))); } return Ok(Timestamp::txn_sequence(sequence)); } @@ -493,13 +493,7 @@ impl ShardStore { }; if let Some(expected_ts) = existing_ts { if expected_ts != existing_value.client_ts() { - bail!( - "txn {:?} key {:?}: expect value at {:?}, but got value at {:?}", - intent.txn.meta, - key, - expected_ts, - existing_value.client_ts() - ) + return Err(DataError::timestamp_mismatch(key, existing_value.client_ts())); } } let value = match value { @@ -520,7 +514,7 @@ impl ShardStore { value: ValueExtractor, sequence: u32, existing_ts: Option, - ) -> Result { + ) -> Result { if let Some(intent) = provision.get_intent(context, &key) { assert!(intent.txn.is_same(txn)); return self.rewrite_txn_intent(context, intent.clone(), key, value, sequence, existing_ts); @@ -528,17 +522,12 @@ impl ShardStore { let commit_ts = txn.commit_ts(); let found = self.get(context, Timestamp::MAX, &key)?; if commit_ts <= found.timestamp { - return Err(anyhow!( - "key {:?}: try to write txn at {:?}, but got value at {:?}", - key, - commit_ts, - found.timestamp - )); + return Err(DataError::conflict_write(key, found.timestamp)); } if let Some(expected_ts) = existing_ts { let client_ts = found.client_ts(); if expected_ts != client_ts { - return Err(anyhow!("key {:?}: expect timestamp at {:?}, but got {:?}", key, expected_ts, client_ts)); + return Err(DataError::timestamp_mismatch(key, client_ts)); } } let value = match value { @@ -559,7 +548,7 @@ impl ShardStore { value: Option, sequence: u32, expect_ts: Option, - ) -> Result { + ) -> Result { self.put_transactional_with_value_extractor( context, provision, @@ -600,29 +589,24 @@ impl ShardStore { txn: &TxnRecord, span: KeySpan, from: Timestamp, - ) -> Result<()> { + ) -> Result, DataError> { if span.end.is_empty() { if let Some(intent) = provision.get_intent(context, &span.key) { if intent.txn.is_same(txn) || intent.txn.commit_ts() > txn.commit_ts() { - return Ok(()); + return Ok(Vec::default()); } - bail!("txn {:?} fail to refresh key {:?} due to write from txn {:?}", txn.meta, span.key, intent.txn); + return Err(DataError::conflict_write(span.key, intent.txn.get().clone())); } let found = self.get(context, txn.commit_ts(), &span.key)?; if found.timestamp <= from { - return Ok(()); + return Ok(Vec::default()); } - bail!( - "txn {:?} fail to refresh key {:?} due to write at timestamp {:?}", - txn.meta, - span.key, - found.timestamp - ); + return Err(DataError::conflict_write(span.key, found.timestamp)); } let mut start = span.key.clone(); while let Some((found, intent)) = provision.find_intent(context, &start, &span.end) { if !intent.txn.is_same(txn) && intent.txn.commit_ts() <= txn.commit_ts() { - bail!("txn {:?} fail to refresh key {:?} due to write from txn {:?}", txn.meta, span.key, intent.txn); + return 
Err(DataError::conflict_write(span.key, intent.txn.get().clone()));
             }
             start.clear();
             start.extend(found);
@@ -632,22 +616,118 @@ impl ShardStore {
         start.extend(&span.key);
         loop {
             let (key, value) = self.find(context, txn.commit_ts(), &start)?;
-            if key.is_empty() || key >= span.end {
+            if key.is_empty() {
+                return Ok(self.range.resume_from(span.end));
+            } else if key >= span.end {
                 break;
             }
             if value.timestamp > from {
-                bail!(
-                    "txn {:?} fail to refresh span {:?} due to write at timestamp {:?}",
-                    txn.meta,
-                    span,
-                    value.timestamp
-                );
+                return Err(DataError::conflict_write(key, value.timestamp));
             }
             start.clear();
             start.extend(&key);
             start.push(0);
         }
-        Ok(())
+        Ok(Vec::default())
+    }
+
+    fn handle_transactional_request(
+        &mut self,
+        context: &mut BatchContext,
+        provision: &mut TxnProvision,
+        txn: &TxnRecord,
+        request: DataRequest,
+    ) -> Result<DataResponse, DataError> {
+        Ok(match request {
+            DataRequest::Get(get) => {
+                let value = self.get_transactional(context, provision, txn, &get.key, get.sequence)?;
+                context.reads.watch(&value.value);
+                let response = GetResponse { value: value.into() };
+                DataResponse::Get(response)
+            },
+            DataRequest::Find(find) => {
+                let (key, value) = self.find_transactional(context, provision, txn, &find.key, find.sequence)?;
+                context.reads.watch(&value.value);
+                let response = FindResponse { key, value: value.into() };
+                DataResponse::Find(response)
+            },
+            DataRequest::Scan(scan) => {
+                let (resume_key, rows) =
+                    self.scan_transactional(context, provision, txn, scan.range, scan.limit, scan.sequence)?;
+                rows.iter().for_each(|row| {
+                    context.reads.watch(&row.value.value);
+                });
+                let response = ScanResponse { resume_key, rows: rows.into_iter().map(Into::into).collect() };
+                DataResponse::Scan(response)
+            },
+            DataRequest::Put(put) => {
+                let ts =
+                    self.put_transactional(context, provision, txn, put.key, put.value, put.sequence, put.expect_ts)?;
+                let response = PutResponse { write_ts: ts };
+                DataResponse::Put(response)
+            },
+            DataRequest::Increment(increment) => {
+                let incremented = self.increment_transactional(
+                    context,
+                    provision,
+                    txn,
+                    increment.key,
+                    increment.increment,
+                    increment.sequence,
+                )?;
+                let response = IncrementResponse { value: incremented };
+                DataResponse::Increment(response)
+            },
+            DataRequest::RefreshRead(refresh) => {
+                let resume_key =
+                    self.refresh_read_transactional(context, provision, txn, refresh.span, refresh.from)?;
+                DataResponse::RefreshRead(RefreshReadResponse { resume_key })
+            },
+        })
+    }
+
+    fn handle_timestamped_request(
+        &mut self,
+        context: &mut BatchContext,
+        ts: Timestamp,
+        request: DataRequest,
+    ) -> Result<DataResponse, DataError> {
+        Ok(match request {
+            DataRequest::Get(get) => {
+                let value = self.get_timestamped(context, ts, &get.key)?;
+                context.reads.watch(&value.value);
+                let response = GetResponse { value: value.into() };
+                DataResponse::Get(response)
+            },
+            DataRequest::Find(find) => {
+                let (key, value) = self.find_timestamped(context, ts, &find.key)?;
+                context.reads.watch(&value.value);
+                let response = FindResponse { key, value: value.into() };
+                DataResponse::Find(response)
+            },
+            DataRequest::Scan(scan) => {
+                let (resume_key, rows) = self.scan_timestamped(context, ts, scan.range, scan.limit)?;
+                rows.iter().for_each(|row| {
+                    context.reads.watch(&row.value.value);
+                });
+                let response = ScanResponse { resume_key, rows: rows.into_iter().map(Into::into).collect() };
+                DataResponse::Scan(response)
+            },
+            DataRequest::Put(put) => {
+                let ts = self.put_timestamped(context, ts, &put.key, put.value, put.expect_ts)?;
+                let response = PutResponse { write_ts: ts };
+                DataResponse::Put(response)
+            },
+            DataRequest::Increment(increment) => {
+                let incremented = self.increment_timestamped(context, ts, &increment.key, increment.increment)?;
+                let response = IncrementResponse { value: incremented };
+                DataResponse::Increment(response)
+            },
+            DataRequest::RefreshRead(refresh_read) => {
+                let resume_key = self.refresh_read_timestamped(context, ts, refresh_read.span, refresh_read.from)?;
+                DataResponse::RefreshRead(RefreshReadResponse { resume_key })
+            },
+        })
+    }
 }
 
@@ -689,18 +769,19 @@ impl DataStore {
         self.stores.iter_mut().find(|shard| shard.range.contains(key))
     }
 
-    fn find_shard_store_mut(&mut self, id: ShardId, key: &[u8]) -> Result<&mut ShardStore> {
+    fn find_shard_store_mut(&mut self, id: ShardId, key: &[u8]) -> Result<&mut ShardStore, DataError> {
         if let Some(store) = self.get_shard_store_mut(id) {
             return Ok(unsafe { std::mem::transmute(store) });
         }
-        self.locate_shard_store_mut(key).ok_or_else(|| anyhow!("shard {id} not found for key {key:?}"))
+        self.locate_shard_store_mut(key).ok_or_else(|| DataError::shard_not_found(key, id))
     }
 
-    fn put(&mut self, key: Vec<u8>, ts: Timestamp, value: Value) -> Result<()> {
+    fn put(&mut self, key: Vec<u8>, ts: Timestamp, value: Value) -> Result<(), DataError> {
         let Some(store) = self.locate_shard_store_mut(&key) else {
-            bail!("key {:?} does not reside in tablet {} with shards {:?}", key, self.id, self.shards)
+            return Err(DataError::shard_not_found(key, ShardId(0)));
         };
-        store.put(key, ts, value)
+        store.put(key, ts, value)?;
+        Ok(())
     }
 
     fn put_if_located(&mut self, key: Vec<u8>, ts: Timestamp, value: Value) -> Result<()> {
@@ -708,7 +789,7 @@ impl DataStore {
         store.put(key, ts, value)
     }
 
-    fn promote(&mut self, context: &mut BatchContext) -> Result<()> {
+    fn promote(&mut self, context: &mut BatchContext) -> Result<(), DataError> {
         for (key, values) in context.cache.timestamped.take().into_iter() {
             let (ts, value) = values.into_iter().next_back().unwrap();
             self.put(key, ts, value)?;
@@ -732,14 +813,27 @@ impl TxnProvision {
             return entry.get().clone();
         }
         let outdated_txn = entry.remove();
-        self.remove_intents(&outdated_txn.write_set);
+        self.remove_intents(&outdated_txn.take_write_set());
         TxnRecord::new(txn.clone())
     }
 
     fn add_txn_writes(&mut self, txn: TxnRecord, writes: Writes) {
+        if writes.is_empty() {
+            return;
+        }
+        let txn = self.transactions.entry(txn.id()).or_insert(txn);
         for write in writes {
+            txn.add_write_span(KeySpan::new_key(write.key.clone()));
             match self.intents.entry(write.key) {
                 BTreeEntry::Occupied(mut entry) => {
+                    trace!(
+                        "txn {} write key {:?} to intent {} with value {:?}, sequence {}",
+                        txn.meta(),
+                        entry.key(),
+                        entry.get().meta(),
+                        write.value,
+                        write.sequence
+                    );
                     let intent = entry.get_mut();
                     intent.push_replicated(write.value, write.sequence);
                 },
@@ -751,45 +845,12 @@ impl TxnProvision {
         }
     }
 
-    fn apply_txn(&mut self, store: &mut DataStore, txn: TxnRecord, writes: Writes) {
-        match txn.status {
-            TxnStatus::Pending => self.add_txn_writes(txn, writes),
-            TxnStatus::Aborted => {
-                let id = txn.id();
-                let Some(existing_txn) = self.transactions.remove(&id) else {
-                    return;
-                };
-                for span in &existing_txn.write_set {
-                    self.intents.remove(&span.key);
-                }
-                if !txn.commit_set.is_empty() {
-                    assert!(txn.write_set.is_empty());
-                    self.transactions.insert(id, txn);
-                }
-            },
-            TxnStatus::Committed => {
-                self.add_txn_writes(txn.clone(), writes);
-                let id = txn.id();
-                let commit_ts = txn.commit_ts();
-                let Some(existing_txn) = self.transactions.remove(&id) else {
-                    return;
-                };
-                for span in &existing_txn.write_set {
-                    let Some((key, intent)) = self.intents.remove_entry(&span.key) else {
-                        continue;
-                    };
-                    let Some(latest) = intent.into_latest(&txn) else {
-                        continue;
-                    };
-                    let value = Value::new_replicated(latest.into_value());
-                    store.put(key, commit_ts, value).unwrap();
-                }
-                if !txn.commit_set.is_empty() {
-                    assert!(txn.write_set.is_empty());
-                    self.transactions.insert(id, txn);
-                }
-            },
-        }
+    fn apply_txn(&mut self, store: &mut DataStore, txn: Transaction, writes: Writes) {
+        let record = self.prepare_txn(&txn);
+        self.add_txn_writes(record, writes);
+        let mut replication = ReplicationTracker::default();
+        self.resolve(store, &mut replication, &txn, true);
+        replication.commit();
     }
 
     fn remove_intents(&mut self, spans: &[KeySpan]) {
@@ -815,6 +876,8 @@ impl TxnProvision {
     ) -> Option<(&'a [u8], &'a TxnIntent)> {
         let bounds = if end.is_empty() {
             (Bound::Included(key.to_owned()), Bound::Unbounded)
+        } else if key >= end {
+            return None;
         } else {
             (Bound::Included(key.to_owned()), Bound::Included(end.to_owned()))
         };
@@ -844,7 +907,7 @@ impl TxnProvision {
         if txn.status == TxnStatus::Pending {
             if txn.epoch() > entry.get().epoch() {
                 let record = entry.remove();
-                self.remove_intents(&record.write_set);
+                self.remove_intents(&record.take_write_set());
             }
             return;
         }
@@ -864,6 +927,7 @@ impl TxnProvision {
         } else {
             current.update(txn);
         }
+        trace!("resolving {} write set {:?}", txn.meta(), write_set.iter().take(10));
         if txn.status == TxnStatus::Aborted {
             for span in write_set {
                 self.intents.remove(&span.key);
@@ -949,6 +1013,7 @@ impl TabletStore {
     }
 
     pub fn apply(&mut self, mut message: protos::DataMessage) -> Result<()> {
+        trace!("apply {message:?}");
         let cursor = MessageId::new(message.epoch, message.sequence);
         self.update_cursor(cursor);
         if let (Some(closed_timestamp), Some(leader_expiration)) =
@@ -968,7 +1033,6 @@ impl TabletStore {
             Temporal::Transaction(txn) => txn,
         };
 
-        let txn = self.prepare_txn(&txn);
         self.apply_txn(txn, message.operation.take().into());
         Ok(())
     }
@@ -992,11 +1056,7 @@ impl TabletStore {
         }
     }
 
-    fn prepare_txn(&mut self, txn: &Transaction) -> TxnRecord {
-        self.provision.prepare_txn(txn)
-    }
-
-    fn apply_txn(&mut self, txn: TxnRecord, writes: Writes) {
+    fn apply_txn(&mut self, txn: Transaction, writes: Writes) {
         self.provision.apply_txn(&mut self.store, txn, writes)
     }
 
@@ -1009,157 +1069,76 @@ impl TabletStore {
         }
     }
 
+    fn handle_timestamped_request(
+        &mut self,
+        context: &mut BatchContext,
+        ts: Timestamp,
+        request: ShardRequest,
+    ) -> Result<ShardResponse, DataError> {
+        let ShardRequest { shard_id, request } = request;
+        let shard_store = self.store.find_shard_store_mut(shard_id.into(), request.key())?;
+        let response = shard_store.handle_timestamped_request(context, ts, request)?;
+        let shard = if shard_id == shard_store.id.into_raw() {
+            None
+        } else {
+            Some(ShardDescriptor { id: shard_id, range: shard_store.range.clone(), tablet_id: self.store.id.into() })
+        };
+        Ok(ShardResponse { response, shard })
+    }
+
     pub fn batch_timestamped(
         &mut self,
         context: &mut BatchContext,
         ts: Timestamp,
         requests: Vec<ShardRequest>,
-    ) -> Result<Vec<ShardResponse>> {
+    ) -> Result<Vec<ShardResponse>, BatchError> {
         let mut responses = Vec::with_capacity(requests.len());
-        for ShardRequest { shard_id, request } in requests.into_iter() {
-            let key = request.key();
-            let shard_store = self.store.find_shard_store_mut(shard_id.into(), key)?;
-            let response = match request {
-                DataRequest::Get(get) => {
-                    let value = shard_store.get_timestamped(context, ts, &get.key)?;
-                    context.reads.watch(&value.value);
-                    let response = GetResponse { value: value.into() };
-                    DataResponse::Get(response)
-                },
-                DataRequest::Find(find) => {
-                    let (key, value) = shard_store.find_timestamped(context, ts, &find.key)?;
-                    context.reads.watch(&value.value);
-                    let response = FindResponse { key, value: value.into() };
-                    DataResponse::Find(response)
-                },
-                DataRequest::Scan(scan) => {
-                    let (resume_key, rows) = shard_store.scan_timestamped(context, ts, scan.range, scan.limit)?;
-                    rows.iter().for_each(|row| {
-                        context.reads.watch(&row.value.value);
-                    });
-                    let response = ScanResponse { resume_key, rows: rows.into_iter().map(Into::into).collect() };
-                    DataResponse::Scan(response)
-                },
-                DataRequest::Put(put) => {
-                    let ts = shard_store.put_timestamped(context, ts, &put.key, put.value, put.expect_ts)?;
-                    let response = PutResponse { write_ts: ts };
-                    DataResponse::Put(response)
-                },
-                DataRequest::Increment(increment) => {
-                    let incremented =
-                        shard_store.increment_timestamped(context, ts, &increment.key, increment.increment)?;
-                    let response = IncrementResponse { value: incremented };
-                    DataResponse::Increment(response)
-                },
-                DataRequest::RefreshRead(refresh_read) => {
-                    shard_store.refresh_read_timestamped(context, ts, refresh_read.span, refresh_read.from)?;
-                    DataResponse::RefreshRead(RefreshReadResponse {})
-                },
-            };
-            let shard = if shard_id == shard_store.id.into_raw() {
-                None
-            } else {
-                Some(ShardDescriptor {
-                    id: shard_id,
-                    range: shard_store.range.clone(),
-                    tablet_id: self.store.id.into(),
-                })
+        for (i, request) in requests.into_iter().enumerate() {
+            let response = match self.handle_timestamped_request(context, ts, request) {
+                Ok(response) => response,
+                Err(err) => return Err(BatchError::with_index(i, err)),
             };
-            responses.push(ShardResponse { response, shard });
+            responses.push(response);
+        }
+        if let Err(err) = self.store.promote(context) {
+            return Err(BatchError::new(err));
         }
-        self.store.promote(context)?;
         Ok(responses)
     }
 
+    fn handle_transactional_request(
+        &mut self,
+        context: &mut BatchContext,
+        txn: &TxnRecord,
+        request: ShardRequest,
+    ) -> Result<ShardResponse, DataError> {
+        let ShardRequest { shard_id, request } = request;
+        let shard_store = self.store.find_shard_store_mut(shard_id.into(), request.key())?;
+        let response = shard_store.handle_transactional_request(context, &mut self.provision, txn, request)?;
+        let shard = if shard_id == shard_store.id.into_raw() {
+            None
+        } else {
+            Some(ShardDescriptor { id: shard_id, range: shard_store.range.clone(), tablet_id: self.store.id.into() })
+        };
+        Ok(ShardResponse { response, shard })
+    }
+
     pub fn batch_transactional(
         &mut self,
         context: &mut BatchContext,
         txn: &Transaction,
         requests: Vec<ShardRequest>,
-    ) -> Result<Vec<ShardResponse>> {
+    ) -> Result<Vec<ShardResponse>, BatchError> {
         let txn = self.provision.prepare_txn(txn);
         let mut responses = Vec::with_capacity(requests.len());
-        for ShardRequest { shard_id, request } in requests.into_iter() {
-            let key = request.key();
-            let shard_store = self.store.find_shard_store_mut(shard_id.into(), key)?;
-            let response = match request {
-                DataRequest::Get(get) => {
-                    let value =
-                        shard_store.get_transactional(context, &self.provision, &txn, &get.key, get.sequence)?;
-                    context.reads.watch(&value.value);
-                    let response = GetResponse { value: value.into() };
-                    DataResponse::Get(response)
-                },
-                DataRequest::Find(find) => {
-                    let (key, value) =
-                        shard_store.find_transactional(context, &self.provision, &txn, &find.key, find.sequence)?;
-                    context.reads.watch(&value.value);
-                    let response = FindResponse { key, value: value.into() };
-                    DataResponse::Find(response)
-                },
-                DataRequest::Scan(scan) => {
-                    let (resume_key, rows) = shard_store.scan_transactional(
-                        context,
-                        &self.provision,
-                        &txn,
-                        scan.range,
-                        scan.limit,
-                        scan.sequence,
-                    )?;
-                    rows.iter().for_each(|row| {
-                        context.reads.watch(&row.value.value);
-                    });
-                    let response = ScanResponse { resume_key, rows: rows.into_iter().map(Into::into).collect() };
-                    DataResponse::Scan(response)
-                },
-                DataRequest::Put(put) => {
-                    let ts = shard_store.put_transactional(
-                        context,
-                        &mut self.provision,
-                        &txn,
-                        put.key,
-                        put.value,
-                        put.sequence,
-                        put.expect_ts,
-                    )?;
-                    let response = PutResponse { write_ts: ts };
-                    DataResponse::Put(response)
-                },
-                DataRequest::Increment(increment) => {
-                    let incremented = shard_store.increment_transactional(
-                        context,
-                        &mut self.provision,
-                        &txn,
-                        increment.key,
-                        increment.increment,
-                        increment.sequence,
-                    )?;
-                    let response = IncrementResponse { value: incremented };
-                    DataResponse::Increment(response)
-                },
-                DataRequest::RefreshRead(refresh) => {
-                    shard_store.refresh_read_transactional(
-                        context,
-                        &mut self.provision,
-                        &txn,
-                        refresh.span,
-                        refresh.from,
-                    )?;
-                    DataResponse::RefreshRead(RefreshReadResponse {})
-                },
-            };
-            let shard = if shard_id == shard_store.id.into_raw() {
-                None
-            } else {
-                Some(ShardDescriptor {
-                    id: shard_id,
-                    range: shard_store.range.clone(),
-                    tablet_id: self.store.id.into(),
-                })
+        for (i, request) in requests.into_iter().enumerate() {
+            let response = match self.handle_transactional_request(context, &txn, request) {
+                Ok(response) => response,
+                Err(err) => return Err(BatchError::with_index(i, err)),
             };
-            responses.push(ShardResponse { response, shard });
+            responses.push(response);
         }
-        self.provision.promote(context)?;
+        self.provision.promote(context).unwrap();
         Ok(responses)
     }
 
@@ -1168,7 +1147,7 @@ impl TabletStore {
         context: &mut BatchContext,
         temporal: &Temporal,
         requests: Vec<ShardRequest>,
-    ) -> Result<Vec<ShardResponse>> {
+    ) -> Result<Vec<ShardResponse>, BatchError> {
         match temporal {
             Temporal::Timestamp(ts) => self.batch_timestamped(context, *ts, requests),
             Temporal::Transaction(txn) => self.batch_transactional(context, txn, requests),
diff --git a/src/tablet/types.rs b/src/tablet/types.rs
index e743de8..167dfc5 100644
--- a/src/tablet/types.rs
+++ b/src/tablet/types.rs
@@ -21,6 +21,7 @@ use tokio::sync::mpsc::UnboundedSender;
 use tokio::sync::oneshot;
 
 use crate::protos::{
+    BatchError,
     BatchRequest,
     BatchResponse,
     ParticipateTxnRequest,
@@ -97,7 +98,7 @@ impl StreamingRequester {
 pub enum TabletRequest {
     Batch {
         batch: BatchRequest,
-        responser: oneshot::Sender<Result<BatchResponse>>,
+        responser: oneshot::Sender<Result<BatchResponse, BatchError>>,
     },
     Deploy {
         epoch: u64,
@@ -125,7 +126,7 @@ pub enum TabletServiceRequest {
     },
     Batch {
         batch: BatchRequest,
-        responser: oneshot::Sender<Result<BatchResponse>>,
+        responser: oneshot::Sender<Result<BatchResponse, BatchError>>,
     },
     UnloadTablet {
         deployment: TabletDeployment,
diff --git a/src/txn.rs b/src/txn.rs
index 695e343..9287a9e 100644
--- a/src/txn.rs
+++ b/src/txn.rs
@@ -13,73 +13,120 @@
 // limitations under the License.
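 
+//! Client-side transaction machinery: `TxnState` tracks a transaction's read,
+//! scan and write sets; `TrackingTxn` shares that state between API callers
+//! and the heartbeat task; `Txn` drives a single transaction through the
+//! `KvClient` trait; and `LazyInitTxn` defers starting a transaction until
+//! the first accessed key is known.
+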
 use std::borrow::Cow;
+use std::cmp::Ordering::*;
 use std::collections::{HashMap, HashSet};
-use std::pin::pin;
+use std::pin::{pin, Pin};
+use std::sync::{Arc, RwLock};
 
 use anyhow::anyhow;
 use asyncs::select;
-use asyncs::sync::watch;
+use asyncs::sync::{Notified, Notify};
 use asyncs::task::TaskHandle;
-use thiserror::Error;
+use ignore_result::Ignore;
+use lazy_init::Lazy;
 use tracing::trace;
 
+use crate::kv::{KvClient, KvError, KvSemantics};
 use crate::protos::{
-    BatchRequest,
-    DataRequest,
     HasTxnMeta,
     HasTxnStatus,
     KeySpan,
-    RefreshReadRequest,
-    ShardRequest,
     Temporal,
     Timestamp,
+    TimestampedKeyValue,
     Transaction,
-    TxnMeta,
     TxnStatus,
-    Uuid,
     Value,
 };
-use crate::tablet::{TabletClient, TabletClientError};
+use crate::tablet::TabletClient;
 use crate::timer::Timer;
 
-struct WritingTxn {
+#[derive(Clone, Debug)]
+struct TimestampKeySpan {
+    ts: Timestamp,
+    span: KeySpan,
+}
+
+struct TxnState {
     txn: Transaction,
+    scan_set: Vec<TimestampKeySpan>,
     read_set: HashMap<Vec<u8>, Timestamp>,
     write_set: HashSet<Vec<u8>>,
     epoch: u32,
     sequence: u32,
 }
 
-impl WritingTxn {
+impl TxnState {
     pub fn new(txn: Transaction) -> Self {
-        Self { txn, read_set: Default::default(), write_set: Default::default(), epoch: 0, sequence: 0 }
+        Self {
+            txn,
+            scan_set: Default::default(),
+            read_set: Default::default(),
+            write_set: Default::default(),
+            epoch: 0,
+            sequence: 0,
+        }
     }
 
     fn check_txn_status(&self) -> Result<()> {
         match self.txn.status {
-            TxnStatus::Aborted => Err(TxnError::TxnAborted { txn_id: self.txn.id(), epoch: self.txn.epoch() }),
-            TxnStatus::Committed => Err(TxnError::TxnCommitted { txn_id: self.txn.id(), epoch: self.txn.epoch() }),
+            TxnStatus::Aborted => Err(KvError::TxnAborted { txn_id: self.txn.id(), epoch: self.txn.epoch() }),
+            TxnStatus::Committed => Err(KvError::TxnCommitted { txn_id: self.txn.id(), epoch: self.txn.epoch() }),
             TxnStatus::Pending => Ok(()),
         }
     }
 
-    fn check_txn(&self, epoch: u32) -> Result<()> {
+    fn update(&mut self, txn: &Transaction) -> Result<()> {
+        self.txn.update(txn);
+        self.check()?;
+        Ok(())
+    }
+
+    fn update_read(&mut self, txn: &Transaction, key: &[u8]) -> Result<()> {
+        self.txn.update(txn);
+        self.check()?;
+        self.add_read(key);
+        Ok(())
+    }
+
+    fn update_scan(&mut self, txn: &Transaction, start: &[u8], end: &[u8]) -> Result<()> {
+        if end.is_empty() {
+            return self.update_read(txn, start);
+        }
+        self.txn.update(txn);
+        self.check()?;
+        self.add_scan(start, end);
+        Ok(())
+    }
+
+    fn update_write(&mut self, txn: &Transaction, key: &[u8]) -> Result<()> {
+        self.txn.update(txn);
+        self.check()?;
+        self.add_write(key);
+        Ok(())
+    }
+
+    fn check(&self) -> Result<()> {
         self.check_txn_status()?;
-        if self.txn.epoch() > epoch {
-            return Err(TxnError::TxnRestarted {
+        if self.txn.epoch() > self.epoch {
+            return Err(KvError::TxnRestarted {
                 txn_id: self.txn.id(),
-                from_epoch: epoch,
+                from_epoch: self.epoch,
                 to_epoch: self.txn.epoch(),
             });
         }
         Ok(())
     }
 
-    fn update_and_check(&mut self, txn: &Transaction) -> Result<()> {
-        let epoch = self.txn.epoch();
-        self.txn.update(txn);
-        self.check_txn(epoch)?;
-        Ok(())
+    fn for_write(&mut self) -> Result<(u32, Transaction)> {
+        self.check()?;
+        self.sequence += 1;
+        Ok((self.sequence, self.txn.clone()))
+    }
+
+    fn for_read(&self) -> Result<(u32, Transaction)> {
+        self.check()?;
+        Ok((self.sequence, self.txn.clone()))
     }
 
     pub fn add_read<'a>(&mut self, key: impl Into<Cow<'a, [u8]>>) {
@@ -89,6 +136,73 @@ impl WritingTxn {
         }
     }
 
+    pub fn add_scan(&mut self, start: &[u8], end: &[u8]) {
+        self.scan_set.push(TimestampKeySpan { ts: self.txn.commit_ts(), span: KeySpan::new_range(start, end) });
+        self.scan_set.sort_by(|a, b| a.span.cmp(&b.span));
+        // Merge pass: keep `scan_set` sorted and pairwise disjoint, splitting
+        // overlapping spans so that every key range records the latest
+        // timestamp at which it was read.
+        let mut i = 1;
+        while i < self.scan_set.len() {
+            // SAFETY: `i - 1` and `i` are distinct in-bounds indices, so the
+            // two mutable references are disjoint. `remove(i)` below neither
+            // reallocates nor moves elements before index `i`, and all uses of
+            // `previous` happen before the `insert` calls that may reallocate.
+            let previous = unsafe { &mut *self.scan_set.as_mut_ptr().wrapping_add(i - 1) };
+            let current = unsafe { &mut *self.scan_set.as_mut_ptr().wrapping_add(i) };
+            if previous.span.end <= current.span.key {
+                i += 1;
+                continue;
+            }
+            let mut current = self.scan_set.remove(i);
+            match (
+                previous.span.key == current.span.key,
+                previous.span.end.cmp(&current.span.end),
+                previous.ts.cmp(&current.ts),
+            ) {
+                (true, Equal, Less) => previous.ts = current.ts,
+                (true, Equal, _) => {},
+                (true, Less, Less | Equal) => {
+                    previous.ts = current.ts;
+                    previous.span.end = current.span.end;
+                },
+                (true, Less, Greater) => {
+                    current.span.key.clone_from(&previous.span.end);
+                    self.scan_set.insert(i, current);
+                    i += 1;
+                },
+                (true, Greater, Equal | Greater) => {},
+                (true, Greater, Less) => {
+                    std::mem::swap(&mut previous.ts, &mut current.ts);
+                    std::mem::swap(&mut previous.span.end, &mut current.span.end);
+                    current.span.key.clone_from(&previous.span.end);
+                    self.scan_set.insert(i, current);
+                    i += 1;
+                },
+                (false, Equal, Less) => {
+                    previous.span.end.clone_from(&current.span.key);
+                    self.scan_set.insert(i, current);
+                    i += 1;
+                },
+                (false, Equal, _) => {},
+                (false, Less, Equal | Greater) => {
+                    current.span.key.clone_from(&previous.span.end);
+                    self.scan_set.insert(i, current);
+                    i += 1;
+                },
+                (false, Less, Less) => {
+                    previous.span.end.clone_from(&current.span.key);
+                    self.scan_set.insert(i, current);
+                    i += 1;
+                },
+                (false, Greater, Less) => {
+                    let next = TimestampKeySpan {
+                        ts: previous.ts,
+                        span: KeySpan { key: current.span.end.clone(), end: std::mem::take(&mut previous.span.end) },
+                    };
+                    previous.span.end = current.span.key.clone();
+                    self.scan_set.insert(i, current);
+                    self.scan_set.insert(i + 1, next);
+                    i += 2;
+                },
+                (false, Greater, Equal | Greater) => {},
+            }
+        }
+    }
+
     pub fn add_write(&mut self, key: &[u8]) {
         match self.read_set.remove_entry(key) {
             Some((key, _value)) => self.write_set.insert(key),
@@ -96,37 +210,14 @@ impl WritingTxn {
         };
     }
 
-    pub fn txn(&self) -> &Transaction {
-        &self.txn
-    }
-
-    pub fn committing_txn(&self) -> Transaction {
-        let mut txn = self.txn.clone();
-        txn.status = TxnStatus::Committed;
-        txn.commit_set = self.write_set.iter().map(|key| KeySpan::new_key(key.to_owned())).collect();
-        txn
-    }
-
-    pub fn aborting_txn(&self) -> Transaction {
+    pub fn for_abort(&self) -> Transaction {
         let mut txn = self.txn.clone();
         txn.abort();
         txn.commit_set = self.write_set.iter().map(|key| KeySpan::new_key(key.to_owned())).collect();
         txn
     }
 
-    fn read_sequence(&self) -> u32 {
-        self.sequence
-    }
-
-    fn write_sequence(&self) -> u32 {
-        self.sequence + 1
-    }
-
-    fn bump_sequence(&mut self) {
-        self.sequence += 1;
-    }
-
-    pub fn restart(&mut self) {
+    pub fn restart(&mut self, commit_ts: Timestamp) {
         if self.epoch == self.txn.epoch() {
             self.txn.restart();
         }
@@ -134,258 +225,398 @@ impl WritingTxn {
         self.sequence = 0;
         self.read_set.clear();
         self.write_set.clear();
+        self.scan_set.clear(); // also drop scans recorded by the previous epoch
+        self.txn.commit_ts = commit_ts;
     }
 
-    pub fn outdated_reads(&self) -> Vec<(Vec<u8>, Timestamp)> {
+    pub fn outdated_reads(&self) -> Vec<TimestampKeySpan> {
         let commit_ts = self.txn.commit_ts();
-        self.read_set.iter().filter(|(_key, ts)| **ts < commit_ts).map(|(key, ts)| (key.to_owned(), *ts)).collect()
+        let mut read_spans =
+            self.scan_set.iter().filter(|TimestampKeySpan { ts, .. }| *ts < commit_ts).cloned().collect::<Vec<_>>();
+
+        self.read_set.iter().filter(|(_key, ts)| **ts < commit_ts).for_each(|(key, ts)| {
+            read_spans.push(TimestampKeySpan { span: KeySpan::new_key(key.clone()), ts: *ts });
+        });
+        read_spans
+    }
+
+    pub fn read_refreshes_for_commit(&self) -> Result<(Transaction, Vec<TimestampKeySpan>)> {
+        self.check()?;
+        let mut txn = self.txn.clone();
+        let outdated_reads = self.outdated_reads();
+        if outdated_reads.is_empty() {
+            txn.status = TxnStatus::Committed;
+            txn.commit_set = self.write_set.iter().map(|key| KeySpan::new_key(key.to_owned())).collect();
+        }
+        Ok((txn, outdated_reads))
     }
 }
 
-impl HasTxnMeta for WritingTxn {
-    fn meta(&self) -> &TxnMeta {
-        self.txn.meta()
+struct TrackingTxn {
+    state: RwLock<TxnState>,
+    notify: Notify,
+}
+
+impl TrackingTxn {
+    pub fn new(txn: Transaction) -> Self {
+        Self { notify: Notify::new(), state: RwLock::new(TxnState::new(txn)) }
+    }
+
+    fn update(&self, txn: &Transaction) -> Result<()> {
+        let mut state = self.state.write().unwrap();
+        state.update(txn)?;
+        self.notify.notify_all();
+        Ok(())
+    }
+
+    fn update_read(&self, txn: &Transaction, key: &[u8]) -> Result<()> {
+        let mut state = self.state.write().unwrap();
+        state.update_read(txn, key)?;
+        self.notify.notify_all();
+        Ok(())
+    }
+
+    fn update_scan(&self, txn: &Transaction, start: &[u8], end: &[u8]) -> Result<()> {
+        let mut state = self.state.write().unwrap();
+        state.update_scan(txn, start, end)?;
+        self.notify.notify_all();
+        Ok(())
+    }
+
+    fn update_write(&self, txn: &Transaction, key: &[u8]) -> Result<()> {
+        let mut state = self.state.write().unwrap();
+        state.update_write(txn, key)?;
+        self.notify.notify_all();
+        Ok(())
+    }
+
+    fn check_notified(&self) -> Result<Box<Notified<'_>>> {
+        let state = self.state.read().unwrap();
+        state.check()?;
+        let mut notified = Box::new(self.notify.notified());
+        Pin::new(notified.as_mut()).enable();
+        Ok(notified)
+    }
+
+    fn write(&self) -> Result<(u32, Transaction, Box<Notified<'_>>)> {
+        let mut state = self.state.write().unwrap();
+        let (sequence, txn) = state.for_write()?;
+        let mut notified = Box::new(self.notify.notified());
+        Pin::new(notified.as_mut()).enable();
+        Ok((sequence, txn.clone(), notified))
+    }
+
+    fn read(&self) -> Result<(u32, Transaction, Box<Notified<'_>>)> {
+        let state = self.state.read().unwrap();
+        let (sequence, txn) = state.for_read()?;
+        let mut notified = Box::new(self.notify.notified());
+        Pin::new(notified.as_mut()).enable();
+        Ok((sequence, txn.clone(), notified))
+    }
+
+    fn restart(&self, commit_ts: Timestamp) {
+        self.state.write().unwrap().restart(commit_ts);
+    }
+
+    pub fn bump_commit_ts(&self, commit_ts: Timestamp) {
+        self.state.write().unwrap().txn.commit_ts = commit_ts;
+    }
+
+    pub fn read_refreshes_for_commit(&self) -> Result<(Transaction, Vec<TimestampKeySpan>)> {
+        self.state.read().unwrap().read_refreshes_for_commit()
+    }
+
+    pub fn for_abort(&self) -> Transaction {
+        self.state.read().unwrap().for_abort()
+    }
+
+    pub fn txn(&self) -> Transaction {
+        self.state.read().unwrap().txn.clone()
+    }
+
+    pub fn commit_ts(&self) -> Timestamp {
+        self.state.read().unwrap().txn.commit_ts()
     }
 }
 
-pub struct Txn {
-    txn: WritingTxn,
+pub struct LazyInitTxn {
     client: TabletClient,
-    heartbeating_txn: watch::Receiver<Transaction>,
-    heartbeating_task: Option<TaskHandle<()>>,
+    txn: Lazy<Txn>,
 }
 
-#[derive(Debug, Error)]
-pub enum TxnError {
-    #[error("{0}")]
-    ClientError(#[from] TabletClientError),
-    #[error("txn {txn_id} restarted from epoch {from_epoch} to {to_epoch}")]
-    TxnRestarted { txn_id: Uuid, from_epoch: u32, to_epoch: u32 },
-    #[error("txn {txn_id} aborted in epoch {epoch}")]
-    TxnAborted { txn_id: Uuid, epoch: u32 },
-    #[error("txn {txn_id} already committed with epoch {epoch}")]
-    TxnCommitted { txn_id: Uuid, epoch: u32 },
-    #[error(transparent)]
-    Internal(#[from] anyhow::Error),
+impl LazyInitTxn {
+    pub fn new(client: TabletClient) -> Self {
+        Self { client, txn: Lazy::new() }
+    }
+
+    fn txn(&self, key: &[u8]) -> &Txn {
+        self.txn.get_or_create(|| Txn::new(self.client.clone(), key.to_owned()))
+    }
 }
 
-impl From<tonic::Status> for TxnError {
-    fn from(status: tonic::Status) -> Self {
-        Self::from(TabletClientError::from(status))
+#[async_trait::async_trait]
+impl KvClient for LazyInitTxn {
+    fn client(&self) -> &TabletClient {
+        &self.client
     }
+
+    fn semantics(&self) -> KvSemantics {
+        KvSemantics::Transactional
+    }
+
+    async fn get(&self, key: &[u8]) -> Result<Option<(Timestamp, Value)>> {
+        self.txn(key).get(key).await
+    }
+
+    async fn scan(&self, start: &[u8], end: &[u8], limit: u32) -> Result<(Vec<u8>, Vec<TimestampedKeyValue>)> {
+        self.txn(start).scan(start, end, limit).await
+    }
+
+    async fn put(&self, key: &[u8], value: Option<Value>, expect_ts: Option<Timestamp>) -> Result<Timestamp> {
+        self.txn(key).put(key, value, expect_ts).await
+    }
+
+    async fn increment(&self, key: &[u8], increment: i64) -> Result<i64> {
+        self.txn(key).increment(key, increment).await
+    }
+
+    async fn commit(&self) -> Result<Timestamp> {
+        self.txn.get().unwrap().commit().await
+    }
+
+    async fn abort(&self) -> Result<()> {
+        self.txn.get().unwrap().abort().await
+    }
+
+    fn restart(&self) -> Result<()> {
+        self.txn.get().unwrap().restart()
+    }
+}
+
+pub struct Txn {
+    txn: Arc<TrackingTxn>,
+    client: TabletClient,
+    #[allow(unused)]
+    heartbeating_task: Option<TaskHandle<()>>,
 }
 
-type Result<T> = std::result::Result<T, TxnError>;
+type Result<T> = std::result::Result<T, KvError>;
 
 impl Txn {
     pub fn new(client: TabletClient, key: impl Into<Vec<u8>>) -> Self {
         let txn = client.new_transaction(key.into());
-        let (heartbeating_txn, heartbeating_task) = Self::start_heartbeat_task(client.clone(), txn.clone());
-        Self { txn: WritingTxn::new(txn), client, heartbeating_txn, heartbeating_task: Some(heartbeating_task) }
+        let (txn, heartbeating_task) = Self::start_heartbeat_task(client.clone(), txn.clone());
+        Self { txn, client, heartbeating_task: Some(heartbeating_task) }
     }
 
-    async fn heartbeat_once(client: &TabletClient, txn: Transaction) -> Result<Transaction> {
-        let (shard, mut service) = client.service(txn.key()).await?;
-        let response = service
-            .batch(BatchRequest {
-                tablet_id: shard.tablet_id().into(),
-                temporal: Temporal::Transaction(txn),
-                ..Default::default()
-            })
-            .await?
-            .into_inner();
-        Ok(response.temporal.into_transaction())
+    pub fn client(&self) -> &TabletClient {
+        &self.client
     }
 
-    async fn heartbeat(client: TabletClient, mut txn: Transaction, sender: watch::Sender<Transaction>) {
+    async fn heartbeat(client: TabletClient, tracking_txn: Arc<TrackingTxn>, mut txn: Transaction) {
         let mut timer = Timer::after(Transaction::HEARTBEAT_INTERVAL / 2);
         loop {
             timer.await;
-            match Self::heartbeat_once(&client, txn.clone()).await {
-                Err(_) => {},
-                Ok(updated_txn) => txn.update(&updated_txn),
-            }
             loop {
-                if let Ok(updated_txn) = Self::heartbeat_once(&client, txn.clone()).await {
+                if let Ok(updated_txn) = client.heartbeat_txn(txn.clone()).await {
                     txn.update(&updated_txn);
                     break;
                 }
                 Timer::after(Transaction::HEARTBEAT_INTERVAL / 8).await;
             }
-            if sender.send(txn.clone()).is_err() || txn.status().is_terminal() {
+            tracking_txn.update(&txn).ignore();
+            if txn.status().is_terminal() {
                 break;
             }
             timer = Timer::after(Transaction::HEARTBEAT_INTERVAL / 2);
         }
     }
 
-    fn start_heartbeat_task(client: TabletClient, txn: Transaction) -> (watch::Receiver<Transaction>, TaskHandle<()>) {
-        let (sender, receiver) = watch::channel(txn.clone());
-        let task = asyncs::spawn(Self::heartbeat(client, txn, sender)).attach();
-        (receiver, task)
+    fn start_heartbeat_task(client: TabletClient, txn: Transaction) -> (Arc<TrackingTxn>, TaskHandle<()>) {
+        let tracking_txn = Arc::new(TrackingTxn::new(txn.clone()));
+        let task = asyncs::spawn(Self::heartbeat(client, tracking_txn.clone(), txn)).attach();
+        (tracking_txn, task)
+    }
+
+    pub fn bump_commit_ts(&self) {
+        self.txn.bump_commit_ts(self.client.now());
     }
 
-    fn sync_heartbeating_txn(&mut self) -> Result<()> {
-        let txn = self.heartbeating_txn.borrow_and_update();
-        if txn.has_changed() {
-            self.txn.update_and_check(&txn)?;
+    async fn refresh_reads(&self) -> Result<Transaction> {
+        loop {
+            let (txn, mut outdated_reads) = self.txn.read_refreshes_for_commit()?;
+            if outdated_reads.is_empty() {
+                return Ok(txn);
+            }
+            let commit_ts = txn.commit_ts();
+            while let Some(TimestampKeySpan { span, ts }) = outdated_reads.pop() {
+                if span.end.is_empty() {
+                    trace!("refreshing read for key {:?} from {ts} to {commit_ts}", span.key);
+                } else {
+                    trace!("refreshing read for span {:?} from {ts} to {commit_ts}", span);
+                }
+                let (temporal, response) = self.client.refresh_read(txn.clone().into(), span.clone(), ts).await?;
+                let txn = temporal.into_transaction();
+                match response {
+                    None => self.txn.update(&txn)?,
+                    Some(resume_key) if resume_key.is_empty() => self.txn.update_scan(&txn, &span.key, &span.end)?,
+                    Some(resume_key) => self.txn.update_scan(&txn, &span.key, &resume_key)?,
+                }
+            }
         }
-        Ok(())
     }
 
-    pub async fn get(&mut self, key: &[u8]) -> Result<Option<(Timestamp, Value)>> {
-        self.sync_heartbeating_txn()?;
-        let mut get = self.client.transactional_get(self.txn.txn().clone(), key, self.txn.read_sequence());
+    pub fn txn(&self) -> Transaction {
+        self.txn.txn()
+    }
+
+    pub fn commit_ts(&self) -> Timestamp {
+        self.txn.commit_ts()
+    }
+}
+
+#[async_trait::async_trait]
+impl KvClient for Txn {
+    fn client(&self) -> &TabletClient {
+        &self.client
+    }
+
+    fn semantics(&self) -> KvSemantics {
+        KvSemantics::Transactional
+    }
+
+    async fn get(&self, key: &[u8]) -> Result<Option<(Timestamp, Value)>> {
+        let (sequence, txn, mut notified) = self.txn.read()?;
+        let mut get = self.client.get_directly(txn.into(), key, sequence);
         let mut get = pin!(get);
         loop {
             select! {
-                r = self.heartbeating_txn.changed() => match r {
-                    Err(_) => return Err(TxnError::Internal(anyhow!("txn {} heartbeat stopped", self.txn.id()))),
-                    Ok(txn) => {
-                        self.txn.update_and_check(&txn)?;
-                    }
-                },
-                r = get.as_mut() => match r {
-                    Err(err) => return Err(TxnError::from(err)),
-                    Ok((txn, value)) => {
-                        self.txn.update_and_check(&txn)?;
-                        self.txn.add_read(key);
-                        return Ok(value.map(|v| (v.timestamp, v.value)));
-                    }
-                },
+                _ = notified => notified = self.txn.check_notified()?,
+                r = get.as_mut() => {
+                    let (temporal, response) = r?;
+                    let txn = match temporal {
+                        Temporal::Transaction(txn) => txn,
+                        Temporal::Timestamp(ts) => return Err(KvError::unexpected(format!("txn get received timestamp piggybacked temporal {ts}"))),
+                    };
+                    self.txn.update_read(&txn, key.as_ref())?;
+                    return match response {
+                        None => Err(KvError::unexpected("txn get receives no response")),
+                        Some(response) => Ok(response.value.map(|v| v.into_parts())),
+                    };
+                }
             }
         }
     }
 
-    pub async fn put(&mut self, key: &[u8], value: Option<Value>, expect_ts: Option<Timestamp>) -> Result<()> {
-        self.sync_heartbeating_txn()?;
-        let mut put =
-            self.client.transactional_put(self.txn.txn().clone(), key, value, self.txn.write_sequence(), expect_ts);
-        let mut put = pin!(put);
+    async fn scan(&self, start: &[u8], end: &[u8], limit: u32) -> Result<(Vec<u8>, Vec<TimestampedKeyValue>)> {
+        let (_sequence, txn, mut notified) = self.txn.read()?;
+        let mut scan = self.client.scan_directly(txn.into(), start, end, limit);
+        let mut scan = pin!(scan);
         loop {
             select! {
-                r = self.heartbeating_txn.changed() => match r {
-                    Err(_) => return Err(TxnError::Internal(anyhow!("txn {} heartbeat stopped", self.txn.id()))),
-                    Ok(txn) => {
-                        self.txn.update_and_check(&txn)?;
-                    }
-                },
-                r = put.as_mut() => match r {
-                    Err(err) => return Err(TxnError::from(err)),
-                    Ok(txn) => {
-                        self.txn.update_and_check(&txn)?;
-                        self.txn.bump_sequence();
-                        self.txn.add_write(key);
-                        return Ok(());
-                    }
-                },
+                _ = notified => notified = self.txn.check_notified()?,
+                r = scan.as_mut() => {
+                    let (temporal, response) = r?;
+                    let txn = match temporal {
+                        Temporal::Transaction(txn) => txn,
+                        Temporal::Timestamp(ts) => return Err(KvError::unexpected(format!("txn scan received timestamp piggybacked temporal {ts}"))),
+                    };
+                    return match response {
+                        None => {
+                            self.txn.update(&txn)?;
+                            Err(KvError::unexpected("txn scan receives no response"))
+                        },
+                        Some((resume_key, rows)) => {
+                            if resume_key.is_empty() || resume_key.as_slice() >= end {
+                                self.txn.update_scan(&txn, start, end)?;
+                            } else {
+                                self.txn.update_scan(&txn, start, &resume_key)?;
+                            }
+                            Ok((resume_key, rows))
+                        }
+                    };
+                }
             }
         }
     }
 
-    pub async fn increment(&mut self, key: &[u8], increment: i64) -> Result<i64> {
-        self.sync_heartbeating_txn()?;
-        let mut increment =
-            self.client.transactional_increment(self.txn.txn().clone(), key, increment, self.txn.write_sequence());
-        let mut increment = pin!(increment);
+    async fn put(&self, key: &[u8], value: Option<Value>, expect_ts: Option<Timestamp>) -> Result<Timestamp> {
+        let (sequence, txn, mut notified) = self.txn.write()?;
+        let mut put = self.client.put_directly(txn.into(), key, value, expect_ts, sequence);
+        let mut put = pin!(put);
         loop {
             select! {
-                r = self.heartbeating_txn.changed() => match r {
-                    Err(_) => return Err(TxnError::Internal(anyhow!("txn {} heartbeat stopped", self.txn.id()))),
-                    Ok(txn) => {
-                        self.txn.update_and_check(&txn)?;
-                    }
-                },
-                r = increment.as_mut() => match r {
-                    Err(err) => return Err(TxnError::from(err)),
-                    Ok((txn, incremented)) => {
-                        self.txn.update_and_check(&txn)?;
-                        self.txn.bump_sequence();
-                        self.txn.add_write(key);
-                        return Ok(incremented);
-                    }
-                },
+                _ = notified => notified = self.txn.check_notified()?,
+                r = put.as_mut() => {
+                    let (temporal, response) = r?;
+                    let txn = match temporal {
+                        Temporal::Transaction(txn) => txn,
+                        Temporal::Timestamp(ts) => return Err(KvError::unexpected(format!("txn put received timestamp piggybacked temporal {ts}"))),
+                    };
+                    self.txn.update_write(&txn, key)?;
+                    return match response {
+                        None => Err(KvError::unexpected("txn put receives no response")),
+                        Some(ts) => Ok(ts),
+                    };
+                }
             }
         }
     }
 
-    pub fn restart(&mut self) {
-        self.txn.restart();
-        self.bump_commit_ts();
-    }
-
-    pub fn bump_commit_ts(&mut self) {
-        self.txn.txn.commit_ts = self.client.now();
-    }
-
-    async fn refresh_reads(&mut self) -> Result<()> {
-        self.sync_heartbeating_txn()?;
+    async fn increment(&self, key: &[u8], increment: i64) -> Result<i64> {
+        let (sequence, txn, mut notified) = self.txn.write()?;
+        let mut increment = self.client.increment_directly(txn.into(), key, increment, sequence);
+        let mut increment = pin!(increment);
         loop {
-            let commit_ts = self.txn.txn().commit_ts();
-            let mut outdated_reads = self.txn.outdated_reads();
-            while let Some((key, ts)) = outdated_reads.pop() {
-                let (shard, mut service) = self.client.service(&key).await?;
-                trace!("refreshing read for key {key:?} from {ts} to {commit_ts}");
-                let requests = vec![ShardRequest {
-                    shard_id: shard.shard_id().into(),
-                    request: DataRequest::RefreshRead(RefreshReadRequest {
-                        span: KeySpan::new_key(key.clone()),
-                        from: ts,
-                    }),
-                }];
-                let response = service
-                    .batch(BatchRequest {
-                        tablet_id: shard.tablet_id().into(),
-                        temporal: Temporal::Transaction(self.txn.txn().clone()),
-                        requests,
-                        ..Default::default()
-                    })
-                    .await?
-                    .into_inner();
-                let txn = response.temporal.into_transaction();
-                self.txn.update_and_check(&txn)?;
-                self.txn.add_read(key);
-            }
-            self.sync_heartbeating_txn()?;
-            if self.txn.txn().commit_ts() <= commit_ts {
-                break;
+            select! {
+                _ = notified => notified = self.txn.check_notified()?,
+                r = increment.as_mut() => {
+                    let (temporal, response) = r?;
+                    let txn = match temporal {
+                        Temporal::Transaction(txn) => txn,
+                        Temporal::Timestamp(ts) => return Err(KvError::unexpected(format!("txn increment received timestamp piggybacked temporal {ts}"))),
+                    };
+                    self.txn.update_write(&txn, key.as_ref())?;
+                    return match response {
+                        None => Err(KvError::unexpected("txn increment receives no response")),
+                        Some(incremented) => Ok(incremented),
+                    };
+                }
             }
         }
-        Ok(())
     }
 
-    pub async fn commit(&mut self) -> Result<()> {
+    async fn commit(&self) -> Result<Timestamp> {
         loop {
-            self.refresh_reads().await?;
-            let txn = Self::heartbeat_once(&self.client, self.txn.committing_txn()).await?;
-            match self.txn.update_and_check(&txn) {
-                Err(TxnError::TxnCommitted { .. }) => break,
+            let txn = self.refresh_reads().await?;
+            let txn = self.client.heartbeat_txn(txn).await?;
+            match self.txn.update(&txn) {
+                Err(KvError::TxnCommitted { .. }) => break,
                 Err(err) => return Err(err),
                 Ok(()) => {},
             }
         }
-        self.heartbeating_task = None;
-        Ok(())
+        Ok(self.commit_ts())
     }
 
-    pub async fn abort(&mut self) -> Result<()> {
-        let txn = Self::heartbeat_once(&self.client, self.txn.aborting_txn()).await?;
-        match self.txn.update_and_check(&txn) {
-            Ok(()) => Err(TxnError::Internal(anyhow!("txn not aborted: {:?}", self.txn.txn()))),
-            Err(TxnError::TxnAborted { .. }) => Ok(()),
+    async fn abort(&self) -> Result<()> {
+        let txn = self.client.heartbeat_txn(self.txn.for_abort()).await?;
+        match self.txn.update(&txn) {
+            Ok(()) => Err(KvError::Internal(anyhow!("txn not aborted: {:?}", self.txn()))),
+            Err(KvError::TxnAborted { .. }) => Ok(()),
             Err(err) => Err(err),
         }
     }
 
-    pub fn txn(&self) -> &Transaction {
-        self.txn.txn()
+    fn restart(&self) -> Result<()> {
+        self.txn.restart(self.client.now());
+        Ok(())
     }
 }
 
 #[cfg(test)]
 mod tests {
+    use std::borrow::Cow;
     use std::time::Duration;
 
     use assertor::*;
@@ -396,6 +627,7 @@ mod tests {
     use crate::cluster::{ClusterEnv, EtcdClusterMetaDaemon, EtcdNodeRegistry, NodeId};
     use crate::endpoint::{Endpoint, Params};
     use crate::keys;
+    use crate::kv::KvClient;
    use crate::log::{LogManager, MemoryLogFactory};
    use crate::protos::{TxnStatus, Value};
    use crate::tablet::{TabletClient, TabletNode};
@@ -424,9 +656,9 @@ mod tests {
        let client = TabletClient::new(cluster_env);
        tokio::time::sleep(Duration::from_secs(20)).await;
 
-        client.put(keys::user_key(b"counter1"), Value::Int(1), None).await.unwrap();
+        client.put(keys::user_key(b"counter1"), Some(Value::Int(1)), None).await.unwrap();
 
-        let mut txn = Txn::new(client.clone(), keys::user_key(b"counter1"));
+        let txn = Txn::new(client.clone(), keys::user_key(b"counter1"));
         txn.put(&keys::system_key(b"counter0"), Some(Value::Int(10)), None).await.unwrap();
         txn.put(&keys::user_key(b"counter1"), None, None).await.unwrap();
         txn.increment(&keys::user_key(b"counter2"), 100).await.unwrap();
@@ -465,7 +697,7 @@ mod tests {
         tokio::time::sleep(Duration::from_secs(20)).await;
 
         let counter_key = keys::user_key(b"counter");
-        let mut txn = Txn::new(client.clone(), counter_key.clone());
+        let txn = Txn::new(client.clone(), counter_key.clone());
 
         assert!(client.get(&counter_key).await.unwrap().is_none());
 
@@ -506,7 +738,7 @@ mod tests {
         tokio::time::sleep(Duration::from_secs(20)).await;
 
         let counter_key = keys::user_key(b"counter");
-        let mut txn = Txn::new(client.clone(), counter_key.clone());
+        let txn = Txn::new(client.clone(), counter_key.clone());
 
         let tablet_id_counter_key = keys::system_key(b"tablet-id-counter");
 
@@ -516,7 +748,7 @@ mod tests {
         txn.bump_commit_ts();
         txn.put(&counter_key, Some(tablet_id_value), None).await.unwrap();
 
         txn.commit().await.unwrap_err();
 
         // Same to above except no write-to-read.
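+        // With no write to a key this transaction has read, the refresh of
+        // outdated reads should succeed and allow the commit to go through.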
let (_ts, tablet_id_value) = txn.get(&tablet_id_counter_key).await.unwrap().unwrap(); @@ -530,6 +762,58 @@ mod tests { assert_that!(value).is_equal_to(tablet_id_value); } + #[test_log::test(tokio::test)] + #[tracing_test::traced_test] + async fn txn_read_refresh_scan() { + let etcd = etcd_container(); + let cluster_uri = etcd.uri().with_path("/team1/seamdb1").unwrap(); + + let node_id = NodeId::new_random(); + let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap(); + let address = format!("http://{}", listener.local_addr().unwrap()); + let endpoint = Endpoint::try_from(address.as_str()).unwrap(); + let (nodes, lease) = + EtcdNodeRegistry::join(cluster_uri.clone(), node_id.clone(), Some(endpoint.to_owned())).await.unwrap(); + let log_manager = + LogManager::new(MemoryLogFactory::new(), &MemoryLogFactory::ENDPOINT, &Params::default()).await.unwrap(); + let cluster_env = ClusterEnv::new(log_manager.into(), nodes).with_replicas(1); + let mut cluster_meta_handle = + EtcdClusterMetaDaemon::start("seamdb1", cluster_uri.clone(), cluster_env.clone()).await.unwrap(); + let descriptor_watcher = cluster_meta_handle.watch_descriptor(None).await.unwrap(); + let deployment_watcher = cluster_meta_handle.watch_deployment(None).await.unwrap(); + let cluster_env = cluster_env.with_descriptor(descriptor_watcher).with_deployment(deployment_watcher.monitor()); + let _node = TabletNode::start(node_id, listener, lease, cluster_env.clone()); + let client = TabletClient::new(cluster_env).scope(keys::USER_KEY_PREFIX); + tokio::time::sleep(Duration::from_secs(20)).await; + + let txn = Txn::new(client.clone(), b"counter".to_vec()); + let mut write = false; + let sum = 'sum: loop { + let mut sum = 0; + let mut start = Cow::Borrowed(b"counter".as_slice()); + while !start.is_empty() { + let (resume_key, rows) = txn.scan(start.as_ref(), b"countes", 0).await.unwrap(); + for row in rows { + if let Value::Int(v) = row.value { + sum += v; + } + } + start = Cow::Owned(resume_key); + if !write { + client.increment(b"counter0", 10).await.unwrap(); + client.increment(b"counter1", 100).await.unwrap(); + write = true; + } + txn.bump_commit_ts(); + match txn.commit().await { + Ok(_) => break 'sum sum, + Err(_) => continue, + } + } + }; + assert_eq!(sum, 110); + } + #[test_log::test(tokio::test)] #[tracing_test::traced_test] async fn txn_commit_push_forward() { @@ -551,13 +835,13 @@ mod tests { let deployment_watcher = cluster_meta_handle.watch_deployment(None).await.unwrap(); let cluster_env = cluster_env.with_descriptor(descriptor_watcher).with_deployment(deployment_watcher.monitor()); let _node = TabletNode::start(node_id, listener, lease, cluster_env.clone()); - let client = TabletClient::new(cluster_env); + let client = TabletClient::new(cluster_env).scope(keys::USER_KEY_PREFIX); tokio::time::sleep(Duration::from_secs(20)).await; - let key = keys::user_key(b"counter"); - let mut txn = Txn::new(client.clone(), key.clone()); + let key = b"counter"; + let txn = Txn::new(client.clone(), key.to_vec()); - txn.put(&key, Some(Value::Int(1)), None).await.unwrap(); + txn.put(key, Some(Value::Int(1)), None).await.unwrap(); let commit_ts = txn.txn().commit_ts(); tokio::time::sleep(Duration::from_secs(30)).await; @@ -592,7 +876,7 @@ mod tests { tokio::time::sleep(Duration::from_secs(20)).await; let key = keys::user_key(b"counter"); - let mut txn = Txn::new(client.clone(), key.clone()); + let txn = Txn::new(client.clone(), key.clone()); txn.put(&key, Some(Value::Int(1)), None).await.unwrap(); txn.abort().await.unwrap();