Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gzip + none compression algos and let SDK pick compression #1802

Merged
merged 1 commit into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 10 additions & 2 deletions crates/bench/benches/subscription.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use spacetimedb::db::relational_db::RelationalDB;
use spacetimedb::error::DBError;
use spacetimedb::execution_context::ExecutionContext;
use spacetimedb::host::module_host::DatabaseTableUpdate;
use spacetimedb::identity::AuthCtx;
use spacetimedb::messages::websocket::BsatnFormat;
use spacetimedb::subscription::query::compile_read_only_query;
use spacetimedb::subscription::subscription::ExecutionSet;
use spacetimedb::{db::relational_db::RelationalDB, messages::websocket::Compression};
use spacetimedb_bench::database::BenchDatabase as _;
use spacetimedb_bench::spacetime_raw::SpacetimeRaw;
use spacetimedb_primitives::{col_list, TableId};
Expand Down Expand Up @@ -104,7 +104,15 @@ fn eval(c: &mut Criterion) {
let query = compile_read_only_query(&raw.db, &AuthCtx::for_testing(), &tx, sql).unwrap();
let query: ExecutionSet = query.into();
let ctx = &ExecutionContext::subscribe(raw.db.address());
b.iter(|| drop(black_box(query.eval::<BsatnFormat>(ctx, &raw.db, &tx, None))))
b.iter(|| {
drop(black_box(query.eval::<BsatnFormat>(
ctx,
&raw.db,
&tx,
None,
Compression::Brotli,
)))
})
});
};

Expand Down
1 change: 1 addition & 0 deletions crates/client-api-messages/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ bytestring.workspace = true
brotli.workspace = true
chrono = { workspace = true, features = ["serde"] }
enum-as-inner.workspace = true
flate2.workspace = true
Centril marked this conversation as resolved.
Show resolved Hide resolved
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
smallvec.workspace = true
Expand Down
79 changes: 60 additions & 19 deletions crates/client-api-messages/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use spacetimedb_sats::{
SpacetimeType,
};
use std::{
io::{self, Read as _},
io::{self, Read as _, Write as _},
sync::Arc,
};

Expand Down Expand Up @@ -74,7 +74,7 @@ pub trait WebsocketFormat: Sized {

/// Convert a `QueryUpdate` into `Self::QueryUpdate`.
/// This allows some formats to e.g., compress the update.
fn into_query_update(qu: QueryUpdate<Self>) -> Self::QueryUpdate;
fn into_query_update(qu: QueryUpdate<Self>, compression: Compression) -> Self::QueryUpdate;
}

/// Messages sent from the client to the server.
Expand Down Expand Up @@ -165,12 +165,15 @@ pub struct OneOffQuery {
pub query_string: Box<str>,
}

/// The tag recognized by ghe host and SDKs to mean no compression of a [`ServerMessage`].
/// The tag recognized by the host and SDKs to mean no compression of a [`ServerMessage`].
pub const SERVER_MSG_COMPRESSION_TAG_NONE: u8 = 0;

/// The tag recognized by the host and SDKs to mean brotli compression of a [`ServerMessage`].
pub const SERVER_MSG_COMPRESSION_TAG_BROTLI: u8 = 1;

/// The tag recognized by the host and SDKs to mean gzip compression of a [`ServerMessage`].
Centril marked this conversation as resolved.
Show resolved Hide resolved
pub const SERVER_MSG_COMPRESSION_TAG_GZIP: u8 = 2;

/// Messages sent from the server to the client.
#[derive(SpacetimeType, derive_more::From)]
#[sats(crate = spacetimedb_lib)]
Expand Down Expand Up @@ -357,13 +360,21 @@ impl<F: WebsocketFormat> TableUpdate<F> {
pub enum CompressableQueryUpdate<F: WebsocketFormat> {
Uncompressed(QueryUpdate<F>),
Brotli(Bytes),
Gzip(Bytes),
}

impl CompressableQueryUpdate<BsatnFormat> {
pub fn maybe_decompress(self) -> QueryUpdate<BsatnFormat> {
match self {
Self::Uncompressed(qu) => qu,
Self::Brotli(bytes) => brotli_decompress_qu(&bytes),
Self::Brotli(bytes) => {
let bytes = brotli_decompress(&bytes).unwrap();
bsatn::from_slice(&bytes).unwrap()
}
Self::Gzip(bytes) => {
let bytes = gzip_decompress(&bytes).unwrap();
bsatn::from_slice(&bytes).unwrap()
}
}
}
}
Expand Down Expand Up @@ -456,7 +467,7 @@ impl WebsocketFormat for JsonFormat {

type QueryUpdate = QueryUpdate<Self>;

fn into_query_update(qu: QueryUpdate<Self>) -> Self::QueryUpdate {
fn into_query_update(qu: QueryUpdate<Self>, _: Compression) -> Self::QueryUpdate {
qu
}
}
Expand Down Expand Up @@ -499,27 +510,50 @@ impl WebsocketFormat for BsatnFormat {

type QueryUpdate = CompressableQueryUpdate<Self>;

fn into_query_update(qu: QueryUpdate<Self>) -> Self::QueryUpdate {
fn into_query_update(qu: QueryUpdate<Self>, compression: Compression) -> Self::QueryUpdate {
let qu_len_would_have_been = bsatn::to_len(&qu).unwrap();

if should_compress(qu_len_would_have_been) {
let bytes = bsatn::to_vec(&qu).unwrap();
let mut out = Vec::new();
brotli_compress(&bytes, &mut out);
CompressableQueryUpdate::Brotli(out.into())
} else {
CompressableQueryUpdate::Uncompressed(qu)
match decide_compression(qu_len_would_have_been, compression) {
Compression::None => CompressableQueryUpdate::Uncompressed(qu),
Compression::Brotli => {
let bytes = bsatn::to_vec(&qu).unwrap();
let mut out = Vec::new();
brotli_compress(&bytes, &mut out);
CompressableQueryUpdate::Brotli(out.into())
}
Compression::Gzip => {
let bytes = bsatn::to_vec(&qu).unwrap();
let mut out = Vec::new();
gzip_compress(&bytes, &mut out);
CompressableQueryUpdate::Gzip(out.into())
}
}
}
}

pub fn should_compress(len: usize) -> bool {
/// The threshold at which we start to compress messages.
/// A specification of either a desired or decided compression algorithm.
///
/// As a *desired* algorithm, this is what a client asks for
/// (it is deserializable so it can be read from the subscribe query parameters).
/// As a *decided* algorithm, this is the result of [`decide_compression`],
/// which downgrades to [`Compression::None`] for payloads at or below the size threshold.
#[derive(serde::Deserialize, Default, PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum Compression {
/// No compression ever, regardless of the payload size.
None,
/// Compress using brotli if a certain size threshold was met.
///
/// This is the default, used when no compression was requested explicitly.
#[default]
Brotli,
/// Compress using gzip if a certain size threshold was met.
Gzip,
}

pub fn decide_compression(len: usize, compression: Compression) -> Compression {
/// The threshold beyond which we start to compress messages.
/// 1KiB was chosen without measurement.
/// TODO(perf): measure!
const COMPRESS_THRESHOLD: usize = 1024;

len <= COMPRESS_THRESHOLD
if len > COMPRESS_THRESHOLD {
compression
} else {
Compression::None
}
}

pub fn brotli_compress(bytes: &[u8], out: &mut Vec<u8>) {
Expand Down Expand Up @@ -560,9 +594,16 @@ pub fn brotli_decompress(bytes: &[u8]) -> Result<Vec<u8>, io::Error> {
Ok(decompressed)
}

pub fn brotli_decompress_qu(bytes: &[u8]) -> QueryUpdate<BsatnFormat> {
let bytes = brotli_decompress(bytes).unwrap();
bsatn::from_slice(&bytes).unwrap()
/// Gzip-compresses `bytes` and appends the compressed stream to `out`.
///
/// Any bytes already in `out` (e.g., a leading compression tag) are left in place.
/// The fastest gzip compression level is used.
///
/// # Panics
///
/// Panics if writing to or finalizing the encoder fails.
pub fn gzip_compress(bytes: &[u8], out: &mut Vec<u8>) {
    let level = flate2::Compression::fast();
    let mut gz = flate2::write::GzEncoder::new(out, level);
    gz.write_all(bytes).unwrap();
    // `finish` flushes the remaining data and writes the gzip trailer.
    gz.finish().expect("Failed to gzip compress `bytes`");
}

/// Decompresses the gzip stream in `bytes` and returns the decoded bytes.
///
/// # Errors
///
/// Returns an [`io::Error`] if the decoder fails while reading,
/// e.g., because `bytes` is not a valid gzip stream.
pub fn gzip_decompress(bytes: &[u8]) -> Result<Vec<u8>, io::Error> {
    let mut decompressed = Vec::new();
    // `Read::read` only fills the buffer's *existing* length,
    // which is zero for a fresh `Vec`, so it would decode nothing.
    // `read_to_end` grows the vector and reads until EOF instead.
    flate2::read::GzDecoder::new(bytes).read_to_end(&mut decompressed)?;
    Ok(decompressed)
}

type RowSize = u16;
Expand Down
15 changes: 11 additions & 4 deletions crates/client-api/src/routes/subscribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use spacetimedb::client::{ClientActorId, ClientConnection, DataMessage, MessageH
use spacetimedb::host::NoSuchModule;
use spacetimedb::util::also_poll;
use spacetimedb::worker_metrics::WORKER_METRICS;
use spacetimedb_client_api_messages::websocket::Compression;
use spacetimedb_lib::address::AddressForUrl;
use spacetimedb_lib::Address;
use std::time::Instant;
Expand All @@ -42,6 +43,7 @@ pub struct SubscribeParams {
#[derive(Deserialize)]
pub struct SubscribeQueryParams {
pub client_address: Option<AddressForUrl>,
pub compression: Option<Compression>,
}

// TODO: is this a reasonable way to generate client addresses?
Expand All @@ -55,7 +57,10 @@ pub fn generate_random_address() -> Address {
pub async fn handle_websocket<S>(
State(ctx): State<S>,
Path(SubscribeParams { name_or_address }): Path<SubscribeParams>,
Query(SubscribeQueryParams { client_address }): Query<SubscribeQueryParams>,
Query(SubscribeQueryParams {
client_address,
compression,
}): Query<SubscribeQueryParams>,
forwarded_for: Option<TypedHeader<XForwardedFor>>,
Extension(auth): Extension<SpacetimeAuth>,
ws: WebSocketUpgrade,
Expand All @@ -80,6 +85,7 @@ where
ws.select_protocol([(BIN_PROTOCOL, Protocol::Binary), (TEXT_PROTOCOL, Protocol::Text)]);

let protocol = protocol.ok_or((StatusCode::BAD_REQUEST, "no valid protocol selected"))?;
let compression = compression.unwrap_or_default();

// TODO: Should also maybe refactor the code and the protocol to allow a single websocket
// to connect to multiple modules
Expand Down Expand Up @@ -131,7 +137,8 @@ where
}

let actor = |client, sendrx| ws_client_actor(client, ws, sendrx);
let client = match ClientConnection::spawn(client_id, protocol, replica_id, module_rx, actor).await {
let client = match ClientConnection::spawn(client_id, protocol, compression, replica_id, module_rx, actor).await
{
Ok(s) => s,
Err(e) => {
log::warn!("ModuleHost died while we were connecting: {e:#}");
Expand Down Expand Up @@ -259,7 +266,7 @@ async fn ws_client_actor_inner(
let workload = msg.workload();
let num_rows = msg.num_rows();

let msg = datamsg_to_wsmsg(serialize(msg, client.protocol));
let msg = datamsg_to_wsmsg(serialize(msg, client.protocol, client.compression));

// These metrics should be updated together,
// or not at all.
Expand Down Expand Up @@ -347,7 +354,7 @@ async fn ws_client_actor_inner(
if let Err(e) = res {
if let MessageHandleError::Execution(err) = e {
log::error!("{err:#}");
let msg = serialize(err, client.protocol);
let msg = serialize(err, client.protocol, client.compression);
if let Err(error) = ws.send(datamsg_to_wsmsg(msg)).await {
log::warn!("Websocket send error: {error}")
}
Expand Down
6 changes: 5 additions & 1 deletion crates/core/src/client/client_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::util::prometheus_handle::IntGaugeExt;
use crate::worker_metrics::WORKER_METRICS;
use derive_more::From;
use futures::prelude::*;
use spacetimedb_client_api_messages::websocket::FormatSwitch;
use spacetimedb_client_api_messages::websocket::{Compression, FormatSwitch};
use spacetimedb_lib::identity::RequestId;
use tokio::sync::{mpsc, oneshot, watch};
use tokio::task::AbortHandle;
Expand All @@ -36,6 +36,7 @@ impl Protocol {
pub struct ClientConnectionSender {
pub id: ClientActorId,
pub protocol: Protocol,
pub compression: Compression,
sendtx: mpsc::Sender<SerializableMessage>,
abort_handle: AbortHandle,
cancelled: AtomicBool,
Expand All @@ -61,6 +62,7 @@ impl ClientConnectionSender {
Self {
id,
protocol,
compression: Compression::Brotli,
sendtx,
abort_handle,
cancelled: AtomicBool::new(false),
Expand Down Expand Up @@ -143,6 +145,7 @@ impl ClientConnection {
pub async fn spawn<F, Fut>(
id: ClientActorId,
protocol: Protocol,
compression: Compression,
replica_id: u64,
mut module_rx: watch::Receiver<ModuleHost>,
actor: F,
Expand Down Expand Up @@ -178,6 +181,7 @@ impl ClientConnection {
let sender = Arc::new(ClientConnectionSender {
id,
protocol,
compression,
sendtx,
abort_handle,
cancelled: AtomicBool::new(false),
Expand Down
32 changes: 22 additions & 10 deletions crates/core/src/client/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use crate::host::ArgsTuple;
use crate::messages::websocket as ws;
use derive_more::From;
use spacetimedb_client_api_messages::websocket::{
BsatnFormat, FormatSwitch, JsonFormat, WebsocketFormat, SERVER_MSG_COMPRESSION_TAG_BROTLI,
SERVER_MSG_COMPRESSION_TAG_NONE,
BsatnFormat, Compression, FormatSwitch, JsonFormat, WebsocketFormat, SERVER_MSG_COMPRESSION_TAG_BROTLI,
SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE,
};
use spacetimedb_lib::identity::RequestId;
use spacetimedb_lib::ser::serde::SerializeWrapper;
Expand All @@ -28,8 +28,13 @@ pub(super) type SwitchedServerMessage = FormatSwitch<ws::ServerMessage<BsatnForm

/// Serialize `msg` into a [`DataMessage`] containing a [`ws::ServerMessage`].
///
/// If `protocol` is [`Protocol::Binary`], the message will be compressed by this method.
pub fn serialize(msg: impl ToProtocol<Encoded = SwitchedServerMessage>, protocol: Protocol) -> DataMessage {
/// If `protocol` is [`Protocol::Binary`],
/// the message will be conditionally compressed by this method according to `compression`.
pub fn serialize(
msg: impl ToProtocol<Encoded = SwitchedServerMessage>,
protocol: Protocol,
compression: Compression,
) -> DataMessage {
// TODO(centril, perf): here we are allocating buffers only to throw them away eventually.
// Consider pooling these allocations so that we reuse them.
match msg.to_protocol(protocol) {
Expand All @@ -40,12 +45,19 @@ pub fn serialize(msg: impl ToProtocol<Encoded = SwitchedServerMessage>, protocol
bsatn::to_writer(&mut msg_bytes, &msg).unwrap();

// Conditionally compress the message.
let msg_bytes = if ws::should_compress(msg_bytes[1..].len()) {
let mut out = vec![SERVER_MSG_COMPRESSION_TAG_BROTLI];
ws::brotli_compress(&msg_bytes[1..], &mut out);
out
} else {
msg_bytes
let srv_msg = &msg_bytes[1..];
let msg_bytes = match ws::decide_compression(srv_msg.len(), compression) {
Compression::None => msg_bytes,
Compression::Brotli => {
let mut out = vec![SERVER_MSG_COMPRESSION_TAG_BROTLI];
ws::brotli_compress(srv_msg, &mut out);
out
}
Compression::Gzip => {
let mut out = vec![SERVER_MSG_COMPRESSION_TAG_GZIP];
ws::gzip_compress(srv_msg, &mut out);
out
}
};
msg_bytes.into()
}
Expand Down
6 changes: 3 additions & 3 deletions crates/core/src/host/module_host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use indexmap::IndexSet;
use itertools::Itertools;
use smallvec::SmallVec;
use spacetimedb_client_api_messages::timestamp::Timestamp;
use spacetimedb_client_api_messages::websocket::{QueryUpdate, WebsocketFormat};
use spacetimedb_client_api_messages::websocket::{Compression, QueryUpdate, WebsocketFormat};
use spacetimedb_data_structures::error_stream::ErrorStream;
use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap};
use spacetimedb_lib::identity::{AuthCtx, RequestId};
Expand Down Expand Up @@ -124,12 +124,12 @@ impl UpdatesRelValue<'_> {
!(self.deletes.is_empty() && self.inserts.is_empty())
}

pub fn encode<F: WebsocketFormat>(&self) -> (F::QueryUpdate, u64) {
pub fn encode<F: WebsocketFormat>(&self, compression: Compression) -> (F::QueryUpdate, u64) {
let (deletes, nr_del) = F::encode_list(self.deletes.iter());
let (inserts, nr_ins) = F::encode_list(self.inserts.iter());
let num_rows = nr_del + nr_ins;
let qu = QueryUpdate { deletes, inserts };
let cqu = F::into_query_update(qu);
let cqu = F::into_query_update(qu, compression);
(cqu, num_rows)
}
}
Expand Down
Loading
Loading