From 5e3ba54ec4e053d7042a9536a8aa017a22900650 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Thu, 21 Sep 2023 13:13:58 +0200 Subject: [PATCH 1/5] fixes https://github.com/prisma/tiberius/issues/302 --- Cargo.toml | 14 ++------- src/bulk_options.rs | 38 +++++++++++++++++++++++ src/client.rs | 74 ++++++++++++++++++++++++++++++++++++++++++--- src/lib.rs | 3 +- 4 files changed, 111 insertions(+), 18 deletions(-) create mode 100644 src/bulk_options.rs diff --git a/Cargo.toml b/Cargo.toml index 4f96e962..a58f22a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ async-trait = "0.1" connection-string = "0.2" num-traits = "0.2" uuid = "1.0" +bitflags = "2.4.0" [target.'cfg(windows)'.dependencies] winauth = { version = "0.0.4", optional = true } @@ -179,18 +180,7 @@ indoc = "1.0.7" features = ["all", "docs"] [features] -all = [ - "chrono", - "time", - "tds73", - "sql-browser-async-std", - "sql-browser-tokio", - "sql-browser-smol", - "integrated-auth-gssapi", - "rust_decimal", - "bigdecimal", - "native-tls", -] +all = ["chrono", "time", "tds73", "sql-browser-async-std", "sql-browser-tokio", "sql-browser-smol", "integrated-auth-gssapi", "rust_decimal", "bigdecimal", "native-tls"] default = ["tds73", "winauth", "native-tls"] tds73 = [] docs = [] diff --git a/src/bulk_options.rs b/src/bulk_options.rs new file mode 100644 index 00000000..40925be9 --- /dev/null +++ b/src/bulk_options.rs @@ -0,0 +1,38 @@ +use bitflags::bitflags; + +bitflags! { + /// Options for MS Sql Bulk Insert + /// see also: https://learn.microsoft.com/en-us/dotnet/api/system.data.sqlclient.sqlbulkcopyoptions?view=dotnet-plat-ext-7.0#fields + pub struct SqlBulkCopyOptions: u32 { + /// Default options + const Default = 0b00000000; + /// Preserve source identity values. When not specified, identity values are assigned by the destination. + const KeepIdentity = 0b00000001; + /// Check constraints while data is being inserted. By default, constraints are not checked. + const CheckConstraints = 0b00000010; + /// Obtain a bulk update lock for the duration of the bulk copy operation. When not specified, row locks are used. + const TableLock = 0b00000100; + /// Preserve null values in the destination table regardless of the settings for default values. When not specified, null values are replaced by default values where applicable. + const KeepNulls = 0b00001000; + /// When specified, cause the server to fire the insert triggers for the rows being inserted into the database. + const FireTriggers = 0b00010000; + } +} + +impl Default for SqlBulkCopyOptions { + fn default() -> Self { + SqlBulkCopyOptions::Default + } +} + +/// The sort order of a column, used for bulk insert +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum SortOrder { + /// Ascending order + Ascending, + /// Descending order + Descending +} + +/// An order hint for bulk insert +pub type ColumOrderHint<'a> = (&'a str, SortOrder); diff --git a/src/client.rs b/src/client.rs index 688721d1..38ecd79f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -14,6 +14,7 @@ pub use auth::*; pub use config::*; pub(crate) use connection::*; +use crate::bulk_options::{SqlBulkCopyOptions, ColumOrderHint}; use crate::tds::stream::ReceivedToken; use crate::{ result::ExecuteResult, @@ -251,6 +252,14 @@ impl Client { Ok(result) } + #[doc(hidden)] // deprecated for bulk_insert_with_options + pub async fn bulk_insert<'a>( + &'a mut self, + table: &'a str, + ) -> crate::Result> { + return self.bulk_insert_with_options(table, &[], Default::default(), &[]).await; + } + /// Execute a `BULK INSERT` statement, efficiantly storing a large number of /// rows to a specified table. Note: make sure the input row follows the same /// schema as the table, otherwise calling `send()` will return an error. @@ -296,15 +305,24 @@ impl Client { /// # Ok(()) /// # } /// ``` - pub async fn bulk_insert<'a>( + pub async fn bulk_insert_with_options<'a>( &'a mut self, table: &'a str, + column_names: &'a [&'a str], + options: SqlBulkCopyOptions, + order_hints: &'a [ColumOrderHint<'a>], ) -> crate::Result> { // Start the bulk request self.connection.flush_stream().await?; // retrieve column metadata from server - let query = format!("SELECT TOP 0 * FROM {}", table); + + let cols_sql = match column_names.len() { + 0 => "*".to_owned(), + _ => column_names.iter().map(|c| format!("\"{}\"", c)).join(", "), + }; + + let query = format!("SELECT TOP 0 {} FROM {}", cols_sql, table); let req = BatchRequest::new(query, self.connection.context().transaction_descriptor()); @@ -333,9 +351,55 @@ impl Client { .collect(); self.connection.flush_stream().await?; - let col_data = columns.iter().map(|c| format!("{}", c)).join(", "); - let query = format!("INSERT BULK {} ({})", table, col_data); - + let col_data = columns.iter().map(|c| format!("\"{}\"", c)).join(", "); + let mut query = format!("INSERT BULK {} ({})", table, col_data); + if options.bits() > 0 || order_hints.len() > 0 { + let mut add_separator = false; + query.push_str(" WITH ("); + if options.contains(SqlBulkCopyOptions::KeepNulls) { + query.push_str("KEEP_NULLS"); + add_separator = true; + } + if options.contains(SqlBulkCopyOptions::TableLock) { + if add_separator { + query.push_str(", "); + } + query.push_str("TABLOCK"); + add_separator = true; + } + if options.contains(SqlBulkCopyOptions::CheckConstraints) { + if add_separator { + query.push_str(", "); + } + query.push_str("CHECK_CONSTRAINTS"); + add_separator = true; + } + if options.contains(SqlBulkCopyOptions::FireTriggers) { + if add_separator { + query.push_str(", "); + } + query.push_str("FIRE_TRIGGERS"); + add_separator = true; + } + if order_hints.len() > 0 { + if add_separator { + query.push_str(", "); + } + query.push_str("ORDER ("); + query.push_str( + &order_hints + .iter() + .map(|(col, order)| format!("{} {}", col, match order { + crate::bulk_options::SortOrder::Ascending => "ASC", + crate::bulk_options::SortOrder::Descending => "DESC", + })) + .join(", "), + ); + query.push_str(")"); + } + query.push_str(")"); + query.push_str(" WITH (KEEPIDENTITY)"); + } let req = BatchRequest::new(query, self.connection.context().transaction_descriptor()); let id = self.connection.context_mut().next_packet_id(); diff --git a/src/lib.rs b/src/lib.rs index 882f5ad3..fd6665db 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -268,6 +268,7 @@ mod row; mod tds; mod sql_browser; +mod bulk_options; pub use client::{AuthMethod, Client, Config}; pub(crate) use error::Error; @@ -284,7 +285,7 @@ pub use tds::{ }; pub use to_sql::{IntoSql, ToSql}; pub use uuid::Uuid; - +pub use bulk_options::{SqlBulkCopyOptions, SortOrder, ColumOrderHint}; use sql_read_bytes::*; use tds::codec::*; From 76a14b37bad33af76ce3b4cb7091c83c6241d831 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Thu, 21 Sep 2023 13:24:35 +0200 Subject: [PATCH 2/5] try new escaping fix --- src/client.rs | 4 ++-- src/tds/codec/token/token_col_metadata.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/client.rs b/src/client.rs index 38ecd79f..6f2e69d8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -319,7 +319,7 @@ impl Client { let cols_sql = match column_names.len() { 0 => "*".to_owned(), - _ => column_names.iter().map(|c| format!("\"{}\"", c)).join(", "), + _ => column_names.iter().map(|c| format!("[{}]", c)).join(", "), }; let query = format!("SELECT TOP 0 {} FROM {}", cols_sql, table); @@ -351,7 +351,7 @@ impl Client { .collect(); self.connection.flush_stream().await?; - let col_data = columns.iter().map(|c| format!("\"{}\"", c)).join(", "); + let col_data = columns.iter().map(|c| format!("{}", c)).join(", "); let mut query = format!("INSERT BULK {} ({})", table, col_data); if options.bits() > 0 || order_hints.len() > 0 { let mut add_separator = false; diff --git a/src/tds/codec/token/token_col_metadata.rs b/src/tds/codec/token/token_col_metadata.rs index 53ffdf1c..6d49938d 100644 --- a/src/tds/codec/token/token_col_metadata.rs +++ b/src/tds/codec/token/token_col_metadata.rs @@ -25,7 +25,7 @@ pub struct MetaDataColumn<'a> { impl<'a> Display for MetaDataColumn<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{} ", self.col_name)?; + write!(f, "[{}] ", self.col_name)?; match &self.base.ty { TypeInfo::FixedLen(fixed) => match fixed { From ae619533512e1d41c72a814f027a84eba0e3893c Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Thu, 21 Sep 2023 13:32:00 +0200 Subject: [PATCH 3/5] small fix --- src/client.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/client.rs b/src/client.rs index 6f2e69d8..bdd79412 100644 --- a/src/client.rs +++ b/src/client.rs @@ -398,7 +398,6 @@ impl Client { query.push_str(")"); } query.push_str(")"); - query.push_str(" WITH (KEEPIDENTITY)"); } let req = BatchRequest::new(query, self.connection.context().transaction_descriptor()); let id = self.connection.context_mut().next_packet_id(); From 326b0088f75617dca98fb7b5704218f92fe8dc2b Mon Sep 17 00:00:00 2001 From: descawed Date: Mon, 9 Oct 2023 09:54:15 -0400 Subject: [PATCH 4/5] Remove extra terminator when encoding zero-length values for large varlen columns --- src/tds/codec/column_data.rs | 80 +++++++++++++++++++++++++++++------- 1 file changed, 66 insertions(+), 14 deletions(-) diff --git a/src/tds/codec/column_data.rs b/src/tds/codec/column_data.rs index ada32781..97602ceb 100644 --- a/src/tds/codec/column_data.rs +++ b/src/tds/codec/column_data.rs @@ -341,8 +341,10 @@ impl<'a> Encode> for ColumnData<'a> { dst.put_u32_le(bytes.len() as u32); dst.extend_from_slice(bytes.as_slice()); - // no next blob - dst.put_u32_le(0u32); + if bytes.len() > 0 { + // no next blob + dst.put_u32_le(0u32); + } } } else if vlc.len() < 0xffff { dst.put_u16_le(0xffff); @@ -407,8 +409,10 @@ impl<'a> Encode> for ColumnData<'a> { )); } - // no next blob - dst.put_u32_le(0u32); + if length > 0 { + // no next blob + dst.put_u32_le(0u32); + } let dst: &mut [u8] = dst.borrow_mut(); let mut dst = &mut dst[len_pos..]; @@ -463,8 +467,10 @@ impl<'a> Encode> for ColumnData<'a> { dst.put_u16_le(chr); } - // PLP_TERMINATOR - dst.put_u32_le(0); + if length > 0 { + // PLP_TERMINATOR + dst.put_u32_le(0); + } let dst: &mut [u8] = dst.borrow_mut(); let bytes = (length * 2).to_le_bytes(); // u32, four bytes @@ -496,8 +502,10 @@ impl<'a> Encode> for ColumnData<'a> { // unknown size dst.put_u64_le(0xfffffffffffffffe); dst.put_u32_le(bytes.len() as u32); - dst.extend(bytes.into_owned()); - dst.put_u32_le(0); + if bytes.len() > 0 { + dst.extend(bytes.into_owned()); + dst.put_u32_le(0); + } } } else if vlc.len() < 0xffff { dst.put_u16_le(0xffff); @@ -519,10 +527,12 @@ impl<'a> Encode> for ColumnData<'a> { dst.put_u64_le(0xfffffffffffffffe_u64); // We'll write in one chunk, length is the whole bytes length dst.put_u32_le(bytes.len() as u32); - // Payload - dst.extend(bytes.into_owned()); - // PLP_TERMINATOR - dst.put_u32_le(0); + if bytes.len() > 0 { + // Payload + dst.extend(bytes.into_owned()); + // PLP_TERMINATOR + dst.put_u32_le(0); + } } (ColumnData::DateTime(opt), Some(TypeInfo::VarLenSized(vlc))) if vlc.r#type() == VarLenType::Datetimen => @@ -705,11 +715,14 @@ mod tests { .encode(&mut buf_with_ti) .expect("encode must succeed"); - let nd = ColumnData::decode(&mut buf.into_sql_read_bytes(), &ti) + let reader = &mut buf.into_sql_read_bytes(); + let nd = ColumnData::decode(reader, &ti) .await .expect("decode must succeed"); - assert_eq!(nd, d) + assert_eq!(nd, d); + + reader.read_u8().await.expect_err("decode must consume entire buffer"); } #[tokio::test] @@ -1025,6 +1038,19 @@ mod tests { .await; } + #[tokio::test] + async fn empty_string_with_varlen_bigvarchar() { + test_round_trip( + TypeInfo::VarLenSized(VarLenContext::new( + VarLenType::BigVarChar, + 0x8ffff, + Some(Collation::new(13632521, 52)), + )), + ColumnData::String(Some("".into())), + ) + .await; + } + #[tokio::test] async fn string_with_varlen_nvarchar() { test_round_trip( @@ -1051,6 +1077,19 @@ mod tests { .await; } + #[tokio::test] + async fn empty_string_with_varlen_nvarchar() { + test_round_trip( + TypeInfo::VarLenSized(VarLenContext::new( + VarLenType::NVarchar, + 0x8ffff, + Some(Collation::new(13632521, 52)), + )), + ColumnData::String(Some("".into())), + ) + .await; + } + #[tokio::test] async fn string_with_varlen_nchar() { test_round_trip( @@ -1157,6 +1196,19 @@ mod tests { .await; } + #[tokio::test] + async fn empty_binary_with_varlen_bigvarbin() { + test_round_trip( + TypeInfo::VarLenSized(VarLenContext::new( + VarLenType::BigVarBin, + 0x8ffff, + Some(Collation::new(13632521, 52)), + )), + ColumnData::Binary(Some(b"".as_slice().into())), + ) + .await; + } + #[tokio::test] async fn datetime_with_varlen_datetimen() { test_round_trip( From 2579d03634fa018f0bd93639ce5ab02715ce60c3 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Wed, 1 May 2024 07:04:15 +0200 Subject: [PATCH 5/5] improve error msg --- src/tds/codec/column_data.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tds/codec/column_data.rs b/src/tds/codec/column_data.rs index 4ed17ae3..816f7075 100644 --- a/src/tds/codec/column_data.rs +++ b/src/tds/codec/column_data.rs @@ -311,8 +311,8 @@ impl<'a> Encode> for ColumnData<'a> { &mut bytes, true, ); - if let encoding_rs::EncoderResult::Unmappable(_) = res { - return Err(crate::Error::Encoding("unrepresentable character".into())); + if let encoding_rs::EncoderResult::Unmappable(c) = res { + return Err(crate::Error::Encoding(format!("unrepresentable character:{}", c).into())); } if bytes.len() > vlc.len() {