diff --git a/Cargo.toml b/Cargo.toml index 4f96e962..a58f22a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ async-trait = "0.1" connection-string = "0.2" num-traits = "0.2" uuid = "1.0" +bitflags = "2.4.0" [target.'cfg(windows)'.dependencies] winauth = { version = "0.0.4", optional = true } @@ -179,18 +180,7 @@ indoc = "1.0.7" features = ["all", "docs"] [features] -all = [ - "chrono", - "time", - "tds73", - "sql-browser-async-std", - "sql-browser-tokio", - "sql-browser-smol", - "integrated-auth-gssapi", - "rust_decimal", - "bigdecimal", - "native-tls", -] +all = ["chrono", "time", "tds73", "sql-browser-async-std", "sql-browser-tokio", "sql-browser-smol", "integrated-auth-gssapi", "rust_decimal", "bigdecimal", "native-tls"] default = ["tds73", "winauth", "native-tls"] tds73 = [] docs = [] diff --git a/src/bulk_options.rs b/src/bulk_options.rs new file mode 100644 index 00000000..40925be9 --- /dev/null +++ b/src/bulk_options.rs @@ -0,0 +1,38 @@ +use bitflags::bitflags; + +bitflags! { + /// Options for MS Sql Bulk Insert + /// see also: https://learn.microsoft.com/en-us/dotnet/api/system.data.sqlclient.sqlbulkcopyoptions?view=dotnet-plat-ext-7.0#fields + pub struct SqlBulkCopyOptions: u32 { + /// Default options + const Default = 0b00000000; + /// Preserve source identity values. When not specified, identity values are assigned by the destination. + const KeepIdentity = 0b00000001; + /// Check constraints while data is being inserted. By default, constraints are not checked. + const CheckConstraints = 0b00000010; + /// Obtain a bulk update lock for the duration of the bulk copy operation. When not specified, row locks are used. + const TableLock = 0b00000100; + /// Preserve null values in the destination table regardless of the settings for default values. When not specified, null values are replaced by default values where applicable. 
+ const KeepNulls = 0b00001000; + /// When specified, cause the server to fire the insert triggers for the rows being inserted into the database. + const FireTriggers = 0b00010000; + } +} + +impl Default for SqlBulkCopyOptions { + fn default() -> Self { + SqlBulkCopyOptions::Default + } +} + +/// The sort order of a column, used for bulk insert +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum SortOrder { + /// Ascending order + Ascending, + /// Descending order + Descending +} + +/// An order hint for bulk insert +pub type ColumOrderHint<'a> = (&'a str, SortOrder); diff --git a/src/client.rs b/src/client.rs index 688721d1..bdd79412 100644 --- a/src/client.rs +++ b/src/client.rs @@ -14,6 +14,7 @@ pub use auth::*; pub use config::*; pub(crate) use connection::*; +use crate::bulk_options::{SqlBulkCopyOptions, ColumOrderHint}; use crate::tds::stream::ReceivedToken; use crate::{ result::ExecuteResult, @@ -251,6 +252,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> { Ok(result) } + #[doc(hidden)] // deprecated for bulk_insert_with_options + pub async fn bulk_insert<'a>( + &'a mut self, + table: &'a str, + ) -> crate::Result<BulkLoadRequest<'a, S>> { + return self.bulk_insert_with_options(table, &[], Default::default(), &[]).await; + } + /// Execute a `BULK INSERT` statement, efficiantly storing a large number of /// rows to a specified table. Note: make sure the input row follows the same /// schema as the table, otherwise calling `send()` will return an error. 
@@ -296,15 +305,24 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> { /// # Ok(()) /// # } /// ``` - pub async fn bulk_insert<'a>( + pub async fn bulk_insert_with_options<'a>( &'a mut self, table: &'a str, + column_names: &'a [&'a str], + options: SqlBulkCopyOptions, + order_hints: &'a [ColumOrderHint<'a>], ) -> crate::Result<BulkLoadRequest<'a, S>> { // Start the bulk request self.connection.flush_stream().await?; // retrieve column metadata from server - let query = format!("SELECT TOP 0 * FROM {}", table); + + let cols_sql = match column_names.len() { + 0 => "*".to_owned(), + _ => column_names.iter().map(|c| format!("[{}]", c)).join(", "), + }; + + let query = format!("SELECT TOP 0 {} FROM {}", cols_sql, table); let req = BatchRequest::new(query, self.connection.context().transaction_descriptor()); @@ -334,8 +352,53 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> { self.connection.flush_stream().await?; let col_data = columns.iter().map(|c| format!("{}", c)).join(", "); - let query = format!("INSERT BULK {} ({})", table, col_data); - + let mut query = format!("INSERT BULK {} ({})", table, col_data); + if options.bits() > 0 || order_hints.len() > 0 { + let mut add_separator = false; + query.push_str(" WITH ("); + if options.contains(SqlBulkCopyOptions::KeepNulls) { + query.push_str("KEEP_NULLS"); + add_separator = true; + } + if options.contains(SqlBulkCopyOptions::TableLock) { + if add_separator { + query.push_str(", "); + } + query.push_str("TABLOCK"); + add_separator = true; + } + if options.contains(SqlBulkCopyOptions::CheckConstraints) { + if add_separator { + query.push_str(", "); + } + query.push_str("CHECK_CONSTRAINTS"); + add_separator = true; + } + if options.contains(SqlBulkCopyOptions::FireTriggers) { + if add_separator { + query.push_str(", "); + } + query.push_str("FIRE_TRIGGERS"); + add_separator = true; + } + if order_hints.len() > 0 { + if add_separator { + query.push_str(", "); + } + query.push_str("ORDER ("); + query.push_str( + &order_hints + .iter() + .map(|(col, order)| format!("{} {}", col, match order { + 
crate::bulk_options::SortOrder::Ascending => "ASC", + crate::bulk_options::SortOrder::Descending => "DESC", + })) + .join(", "), + ); + query.push_str(")"); + } + query.push_str(")"); + } let req = BatchRequest::new(query, self.connection.context().transaction_descriptor()); let id = self.connection.context_mut().next_packet_id(); diff --git a/src/lib.rs b/src/lib.rs index 882f5ad3..fd6665db 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -268,6 +268,7 @@ mod row; mod tds; mod sql_browser; +mod bulk_options; pub use client::{AuthMethod, Client, Config}; pub(crate) use error::Error; @@ -284,7 +285,7 @@ pub use tds::{ }; pub use to_sql::{IntoSql, ToSql}; pub use uuid::Uuid; - +pub use bulk_options::{SqlBulkCopyOptions, SortOrder, ColumOrderHint}; use sql_read_bytes::*; use tds::codec::*; diff --git a/src/tds/codec/column_data.rs b/src/tds/codec/column_data.rs index 6df30173..816f7075 100644 --- a/src/tds/codec/column_data.rs +++ b/src/tds/codec/column_data.rs @@ -311,8 +311,8 @@ impl<'a> Encode<BytesMutWithTypeInfo<'a>> for ColumnData<'a> { &mut bytes, true, ); - if let encoding_rs::EncoderResult::Unmappable(_) = res { - return Err(crate::Error::Encoding("unrepresentable character".into())); + if let encoding_rs::EncoderResult::Unmappable(c) = res { + return Err(crate::Error::Encoding(format!("unrepresentable character:{}", c).into())); } if bytes.len() > vlc.len() { @@ -720,12 +720,7 @@ mod tests { .await .expect("decode must succeed"); - assert_eq!(nd, d); - - reader - .read_u8() - .await - .expect_err("decode must consume entire buffer"); + assert_eq!(nd, d) } #[tokio::test] diff --git a/src/tds/codec/token/token_col_metadata.rs b/src/tds/codec/token/token_col_metadata.rs index 53ffdf1c..6d49938d 100644 --- a/src/tds/codec/token/token_col_metadata.rs +++ b/src/tds/codec/token/token_col_metadata.rs @@ -25,7 +25,7 @@ pub struct MetaDataColumn<'a> { impl<'a> Display for MetaDataColumn<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{} ", self.col_name)?; + 
write!(f, "[{}] ", self.col_name)?; match &self.base.ty { TypeInfo::FixedLen(fixed) => match fixed {