From 76ba349b90719724e06bccd1639c724711959ac9 Mon Sep 17 00:00:00 2001 From: Sijie Yang Date: Fri, 12 Jul 2024 19:43:49 +0800 Subject: [PATCH] Support customized connection id generator --- cbindgen.toml | 2 +- include/tquic.h | 42 ++++++++++++++++++++++++++++++++ src/connection/connection.rs | 4 +-- src/endpoint.rs | 7 +++++- src/ffi.rs | 47 ++++++++++++++++++++++++++++++++++++ src/lib.rs | 16 +++--------- 6 files changed, 101 insertions(+), 17 deletions(-) diff --git a/cbindgen.toml b/cbindgen.toml index d9f8670b..54751f03 100644 --- a/cbindgen.toml +++ b/cbindgen.toml @@ -19,7 +19,7 @@ sys_includes = ["sys/socket.h", "sys/types.h"] includes = ["openssl/ssl.h", "tquic_def.h"] [export] -exclude = ["MAX_CID_LEN", "MIN_CLIENT_INITIAL_LEN", "VINT_MAX"] +exclude = ["MIN_CLIENT_INITIAL_LEN", "VINT_MAX"] [export.rename] "Config" = "quic_config_t" diff --git a/include/tquic.h b/include/tquic.h index f2eebfd6..58896760 100644 --- a/include/tquic.h +++ b/include/tquic.h @@ -22,6 +22,12 @@ */ #define QUIC_VERSION_V1 1 +/** + * The Connection ID MUST NOT exceed 20 bytes in QUIC version 1. + * See RFC 9000 Section 17.2 + */ +#define MAX_CID_LEN 20 + /** * Available congestion control algorithms. */ @@ -224,6 +230,34 @@ typedef struct quic_packet_send_methods_t { typedef void *quic_packet_send_context_t; +/** + * Connection Id is an identifier used to identify a QUIC connection + * at an endpoint. + */ +typedef struct ConnectionId { + /** + * length of cid + */ + uint8_t len; + /** + * octets of cid + */ + uint8_t data[MAX_CID_LEN]; +} ConnectionId; + +typedef struct ConnectionIdGeneratorMethods { + /** + * Generate a new CID + */ + struct ConnectionId (*generate)(void *gctx); + /** + * Return the length of a CID + */ + uint8_t (*cid_len)(void *gctx); +} ConnectionIdGeneratorMethods; + +typedef void *ConnectionIdGeneratorContext; + /** * Meta information of an incoming packet. */ @@ -707,6 +741,14 @@ struct quic_endpoint_t *quic_endpoint_new(struct quic_config_t *config, */ void quic_endpoint_free(struct quic_endpoint_t *endpoint); +/** + * Set the connection id generator for the endpoint. + * By default, the random connection id generator is used. + */ +void quic_endpoint_set_cid_generator(struct quic_endpoint_t *endpoint, + const struct ConnectionIdGeneratorMethods *cid_gen_methods, + ConnectionIdGeneratorContext cid_gen_ctx); + /** * Create a client connection. * If success, the output parameter `index` carrys the index of the connection. diff --git a/src/connection/connection.rs b/src/connection/connection.rs index 3c03134b..258b80d1 100644 --- a/src/connection/connection.rs +++ b/src/connection/connection.rs @@ -4266,8 +4266,8 @@ pub(crate) mod tests { server_config: &mut Config, server_name: &str, ) -> Result { - let mut cli_cid_gen = RandomConnectionIdGenerator::new(client_config.cid_len, None); - let mut srv_cid_gen = RandomConnectionIdGenerator::new(server_config.cid_len, None); + let mut cli_cid_gen = RandomConnectionIdGenerator::new(client_config.cid_len); + let mut srv_cid_gen = RandomConnectionIdGenerator::new(server_config.cid_len); let client_scid = cli_cid_gen.generate(); let server_scid = srv_cid_gen.generate(); let client_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 9443); diff --git a/src/endpoint.rs b/src/endpoint.rs index 48498dee..4f4cacc0 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -112,7 +112,6 @@ impl Endpoint { ) -> Self { let cid_gen = Box::new(crate::RandomConnectionIdGenerator { cid_len: config.cid_len, - cid_lifetime: None, }); let trace_id = if is_server { "SERVER" } else { "CLIENT" }; let buffer = PacketBuffer::new(config.zerortt_buffer_size); @@ -802,6 +801,12 @@ impl Endpoint { self.conns.clear(); } + /// Set the connection id generator + /// By default, the RandomConnectionIdGenerator is used. + pub fn set_cid_generator(&mut self, cid_gen: Box) { + self.cid_gen = cid_gen; + } + /// Set the unique trace id for the endpoint pub fn set_trace_id(&mut self, trace_id: String) { self.trace_id = trace_id diff --git a/src/ffi.rs b/src/ffi.rs index 86562f3a..9ff8a1bd 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -725,6 +725,21 @@ pub extern "C" fn quic_endpoint_free(endpoint: *mut Endpoint) { }; } +/// Set the connection id generator for the endpoint. +/// By default, the random connection id generator is used. +#[no_mangle] +pub extern "C" fn quic_endpoint_set_cid_generator( + endpoint: &mut Endpoint, + cid_gen_methods: *const ConnectionIdGeneratorMethods, + cid_gen_ctx: ConnectionIdGeneratorContext, +) { + let cid_generator = Box::new(ConnectionIdGenerator { + methods: cid_gen_methods, + context: cid_gen_ctx, + }); + endpoint.set_cid_generator(cid_generator); +} + /// Create a client connection. /// If success, the output parameter `index` carrys the index of the connection. /// Note: The `config` specific to the endpoint or server is irrelevant and will be disregarded. @@ -1773,6 +1788,38 @@ pub struct PacketOutSpec { dst_addr_len: socklen_t, } +#[repr(C)] +pub struct ConnectionIdGeneratorMethods { + /// Generate a new CID + pub generate: fn(gctx: *mut c_void) -> ConnectionId, + + /// Return the length of a CID + pub cid_len: fn(gctx: *mut c_void) -> u8, +} + +#[repr(transparent)] +pub struct ConnectionIdGeneratorContext(*mut c_void); + +/// cbindgen:no-export +#[repr(C)] +pub struct ConnectionIdGenerator { + pub methods: *const ConnectionIdGeneratorMethods, + pub context: ConnectionIdGeneratorContext, +} + +impl crate::ConnectionIdGenerator for ConnectionIdGenerator { + /// Generate a new CID + fn generate(&mut self) -> ConnectionId { + unsafe { ((*self.methods).generate)(self.context.0) } + } + + /// Return the length of a CID + fn cid_len(&self) -> usize { + let cid_len = unsafe { ((*self.methods).cid_len)(self.context.0) }; + cid_len as usize + } +} + /// Create default config for HTTP3. #[no_mangle] pub extern "C" fn http3_config_new() -> *mut Http3Config { diff --git a/src/lib.rs b/src/lib.rs index c9d47d22..26a1be96 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -155,6 +155,7 @@ pub type Result = std::result::Result; /// Connection Id is an identifier used to identify a QUIC connection /// at an endpoint. +#[repr(C)] #[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash, Default)] pub struct ConnectionId { /// length of cid @@ -214,9 +215,6 @@ pub trait ConnectionIdGenerator { /// Return the length of a CID fn cid_len(&self) -> usize; - /// Return the lifetime of CID - fn cid_lifetime(&self) -> Option; - /// Generate a new CID and associated reset token. fn generate_cid_and_token(&mut self, reset_token_key: &hmac::Key) -> (ConnectionId, u128) { let scid = self.generate(); @@ -229,14 +227,12 @@ pub trait ConnectionIdGenerator { #[derive(Debug, Clone, Copy)] pub struct RandomConnectionIdGenerator { cid_len: usize, - cid_lifetime: Option, } impl RandomConnectionIdGenerator { - pub fn new(cid_len: usize, cid_lifetime: Option) -> Self { + pub fn new(cid_len: usize) -> Self { Self { cid_len: cmp::min(cid_len, MAX_CID_LEN), - cid_lifetime, } } } @@ -251,10 +247,6 @@ impl ConnectionIdGenerator for RandomConnectionIdGenerator { fn cid_len(&self) -> usize { self.cid_len } - - fn cid_lifetime(&self) -> Option { - self.cid_lifetime - } } /// Meta information about a packet. @@ -1085,11 +1077,9 @@ mod tests { #[test] fn connection_id() { - let lifetime = Duration::from_secs(3600); - let mut cid_gen = RandomConnectionIdGenerator::new(8, Some(lifetime)); + let mut cid_gen = RandomConnectionIdGenerator::new(8); let cid = cid_gen.generate(); assert_eq!(cid.len(), cid_gen.cid_len()); - assert_eq!(Some(lifetime), cid_gen.cid_lifetime()); let cid = ConnectionId { len: 4,