Bulk Improvements #312

Open · wants to merge 7 commits into main

Changes from 5 commits
14 changes: 2 additions & 12 deletions Cargo.toml
@@ -51,6 +51,7 @@ async-trait = "0.1"
connection-string = "0.2"
num-traits = "0.2"
uuid = "1.0"
bitflags = "2.4.0"

[target.'cfg(windows)'.dependencies]
winauth = { version = "0.0.4", optional = true }
@@ -179,18 +180,7 @@ indoc = "1.0.7"
features = ["all", "docs"]

[features]
all = [
"chrono",
"time",
"tds73",
"sql-browser-async-std",
"sql-browser-tokio",
"sql-browser-smol",
"integrated-auth-gssapi",
"rust_decimal",
"bigdecimal",
"native-tls",
]
all = ["chrono", "time", "tds73", "sql-browser-async-std", "sql-browser-tokio", "sql-browser-smol", "integrated-auth-gssapi", "rust_decimal", "bigdecimal", "native-tls"]
default = ["tds73", "winauth", "native-tls"]
tds73 = []
docs = []
38 changes: 38 additions & 0 deletions src/bulk_options.rs
@@ -0,0 +1,38 @@
use bitflags::bitflags;

bitflags! {
    /// Options for an MS SQL bulk insert.
    /// See also: <https://learn.microsoft.com/en-us/dotnet/api/system.data.sqlclient.sqlbulkcopyoptions?view=dotnet-plat-ext-7.0#fields>
    pub struct SqlBulkCopyOptions: u32 {
        /// Default options.
        const Default = 0b00000000;
        /// Preserve source identity values. When not specified, identity values are assigned by the destination.
        const KeepIdentity = 0b00000001;
        /// Check constraints while data is being inserted. By default, constraints are not checked.
        const CheckConstraints = 0b00000010;
        /// Obtain a bulk update lock for the duration of the bulk copy operation. When not specified, row locks are used.
        const TableLock = 0b00000100;
        /// Preserve null values in the destination table regardless of the settings for default values. When not specified, null values are replaced by default values where applicable.
        const KeepNulls = 0b00001000;
        /// When specified, cause the server to fire the insert triggers for the rows being inserted into the database.
        const FireTriggers = 0b00010000;
    }
}

impl Default for SqlBulkCopyOptions {
    fn default() -> Self {
        SqlBulkCopyOptions::Default
    }
}

/// The sort order of a column, used for bulk insert
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SortOrder {
    /// Ascending order
    Ascending,
    /// Descending order
    Descending,
}

/// An order hint for bulk insert
pub type ColumnOrderHint<'a> = (&'a str, SortOrder);
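As a usage illustration (not part of the diff): options combine with bitwise OR, and each order hint pairs a column name with a `SortOrder`. A minimal sketch, assuming an already-connected tokio-based `Client` and a hypothetical `dbo.items (id INT, name NVARCHAR(50))` table:

```rust
use tiberius::{Client, IntoSql, SortOrder, SqlBulkCopyOptions, TokenRow};
use tokio::net::TcpStream;
use tokio_util::compat::Compat;

// Hypothetical example; `client` is assumed to be connected already.
async fn load(client: &mut Client<Compat<TcpStream>>) -> anyhow::Result<()> {
    let options = SqlBulkCopyOptions::KeepNulls | SqlBulkCopyOptions::TableLock;
    let hints = [("id", SortOrder::Ascending)];

    let mut req = client
        .bulk_insert_with_options("dbo.items", &["id", "name"], options, &hints)
        .await?;

    // One row per send; the row must match the table schema.
    let mut row = TokenRow::new();
    row.push(1i32.into_sql());
    row.push("first".into_sql());
    req.send(row).await?;

    req.finalize().await?;
    Ok(())
}
```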
71 changes: 67 additions & 4 deletions src/client.rs
@@ -14,6 +14,7 @@ pub use auth::*;
pub use config::*;
pub(crate) use connection::*;

use crate::bulk_options::{SqlBulkCopyOptions, ColumnOrderHint};
use crate::tds::stream::ReceivedToken;
use crate::{
result::ExecuteResult,
@@ -251,6 +252,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
Ok(result)
}

#[doc(hidden)] // deprecated in favour of `bulk_insert_with_options`
pub async fn bulk_insert<'a>(
    &'a mut self,
    table: &'a str,
) -> crate::Result<BulkLoadRequest<'a, S>> {
    self.bulk_insert_with_options(table, &[], Default::default(), &[])
        .await
}

/// Execute a `BULK INSERT` statement, efficiently storing a large number of
/// rows to a specified table. Note: make sure the input row follows the same
/// schema as the table, otherwise calling `send()` will return an error.
@@ -296,15 +305,24 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
/// # Ok(())
/// # }
/// ```
pub async fn bulk_insert<'a>(
pub async fn bulk_insert_with_options<'a>(
    &'a mut self,
    table: &'a str,
    column_names: &'a [&'a str],
    options: SqlBulkCopyOptions,
    order_hints: &'a [ColumnOrderHint<'a>],
) -> crate::Result<BulkLoadRequest<'a, S>> {
    // Start the bulk request
    self.connection.flush_stream().await?;

    // retrieve column metadata from server
    let query = format!("SELECT TOP 0 * FROM {}", table);

    // When no explicit column list is given, fall back to all columns.
    let cols_sql = match column_names.len() {
        0 => "*".to_owned(),
        _ => column_names.iter().map(|c| format!("[{}]", c)).join(", "),
    };

    let query = format!("SELECT TOP 0 {} FROM {}", cols_sql, table);

let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());

@@ -334,8 +352,53 @@

self.connection.flush_stream().await?;
let col_data = columns.iter().map(|c| format!("{}", c)).join(", ");
let query = format!("INSERT BULK {} ({})", table, col_data);

let mut query = format!("INSERT BULK {} ({})", table, col_data);

if !options.is_empty() || !order_hints.is_empty() {
    let mut add_separator = false;
    query.push_str(" WITH (");

    if options.contains(SqlBulkCopyOptions::KeepNulls) {

Review comment: Can't you use bitflags's Flag::iter and join to simplify the whole separator business? (See the sketch after this file's diff.)

        query.push_str("KEEP_NULLS");
        add_separator = true;
    }
    if options.contains(SqlBulkCopyOptions::TableLock) {
        if add_separator {
            query.push_str(", ");
        }
        query.push_str("TABLOCK");
        add_separator = true;
    }
    if options.contains(SqlBulkCopyOptions::CheckConstraints) {
        if add_separator {
            query.push_str(", ");
        }
        query.push_str("CHECK_CONSTRAINTS");
        add_separator = true;
    }
    if options.contains(SqlBulkCopyOptions::FireTriggers) {
        if add_separator {
            query.push_str(", ");
        }
        query.push_str("FIRE_TRIGGERS");
        add_separator = true;
    }
    if !order_hints.is_empty() {
        if add_separator {
            query.push_str(", ");
        }
        query.push_str("ORDER (");
        query.push_str(
            &order_hints
                .iter()
                .map(|(col, order)| {
                    format!("{} {}", col, match order {
                        crate::bulk_options::SortOrder::Ascending => "ASC",
                        crate::bulk_options::SortOrder::Descending => "DESC",
                    })
                })
                .join(", "),
        );
        query.push_str(")");
    }
    query.push_str(")");
}
let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
let id = self.connection.context_mut().next_packet_id();

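On the inline review comment above: a sketch of the suggested simplification, replacing the separator bookkeeping with a keyword table plus `join`. This is a hypothetical refactoring (the helper name is invented), not part of the diff:

```rust
use crate::bulk_options::SqlBulkCopyOptions;

// Hypothetical helper: map each set flag to its T-SQL keyword, then join.
// No `add_separator` state is needed.
fn with_options_clause(options: SqlBulkCopyOptions) -> String {
    let keywords = [
        (SqlBulkCopyOptions::KeepNulls, "KEEP_NULLS"),
        (SqlBulkCopyOptions::TableLock, "TABLOCK"),
        (SqlBulkCopyOptions::CheckConstraints, "CHECK_CONSTRAINTS"),
        (SqlBulkCopyOptions::FireTriggers, "FIRE_TRIGGERS"),
    ];

    let mut parts = Vec::new();
    for (flag, keyword) in keywords {
        if options.contains(flag) {
            parts.push(keyword);
        }
    }
    parts.join(", ")
}
```

bitflags 2.x does expose an `iter()` over the set flags, as the reviewer suggests, but the Rust flag names still have to be mapped to their T-SQL spellings, so some form of lookup table is needed either way.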
3 changes: 2 additions & 1 deletion src/lib.rs
@@ -268,6 +268,7 @@ mod row;
mod tds;

mod sql_browser;
mod bulk_options;

pub use client::{AuthMethod, Client, Config};
pub(crate) use error::Error;
@@ -284,7 +285,7 @@ pub use tds::{
};
pub use to_sql::{IntoSql, ToSql};
pub use uuid::Uuid;

pub use bulk_options::{SqlBulkCopyOptions, SortOrder, ColumnOrderHint};
use sql_read_bytes::*;
use tds::codec::*;

80 changes: 66 additions & 14 deletions src/tds/codec/column_data.rs
@@ -341,8 +341,10 @@ impl<'a> Encode<BytesMutWithTypeInfo<'a>> for ColumnData<'a> {
dst.put_u32_le(bytes.len() as u32);
dst.extend_from_slice(bytes.as_slice());

// no next blob
dst.put_u32_le(0u32);
if !bytes.is_empty() {
    // no next blob
    dst.put_u32_le(0u32);
}
}
} else if vlc.len() < 0xffff {
dst.put_u16_le(0xffff);
@@ -407,8 +409,10 @@ impl<'a> Encode<BytesMutWithTypeInfo<'a>> for ColumnData<'a> {
));
}

// no next blob
dst.put_u32_le(0u32);
if length > 0 {
    // no next blob
    dst.put_u32_le(0u32);
}

let dst: &mut [u8] = dst.borrow_mut();
let mut dst = &mut dst[len_pos..];
@@ -463,8 +467,10 @@ impl<'a> Encode<BytesMutWithTypeInfo<'a>> for ColumnData<'a> {
dst.put_u16_le(chr);
}

// PLP_TERMINATOR
dst.put_u32_le(0);
if length > 0 {
    // PLP_TERMINATOR
    dst.put_u32_le(0);
}

let dst: &mut [u8] = dst.borrow_mut();
let bytes = (length * 2).to_le_bytes(); // u32, four bytes
@@ -496,8 +502,10 @@ impl<'a> Encode<BytesMutWithTypeInfo<'a>> for ColumnData<'a> {
// unknown size
dst.put_u64_le(0xfffffffffffffffe);
dst.put_u32_le(bytes.len() as u32);
dst.extend(bytes.into_owned());
dst.put_u32_le(0);
if !bytes.is_empty() {
    dst.extend(bytes.into_owned());
    dst.put_u32_le(0);
}
}
} else if vlc.len() < 0xffff {
dst.put_u16_le(0xffff);
@@ -519,10 +527,12 @@ impl<'a> Encode<BytesMutWithTypeInfo<'a>> for ColumnData<'a> {
dst.put_u64_le(0xfffffffffffffffe_u64);
// We'll write in one chunk, length is the whole bytes length
dst.put_u32_le(bytes.len() as u32);
// Payload
dst.extend(bytes.into_owned());
// PLP_TERMINATOR
dst.put_u32_le(0);
if !bytes.is_empty() {
    // Payload
    dst.extend(bytes.into_owned());
    // PLP_TERMINATOR
    dst.put_u32_le(0);
}
}
(ColumnData::DateTime(opt), Some(TypeInfo::VarLenSized(vlc)))
if vlc.r#type() == VarLenType::Datetimen =>
@@ -705,11 +715,14 @@ mod tests {
.encode(&mut buf_with_ti)
.expect("encode must succeed");

let nd = ColumnData::decode(&mut buf.into_sql_read_bytes(), &ti)
let reader = &mut buf.into_sql_read_bytes();
let nd = ColumnData::decode(reader, &ti)
    .await
    .expect("decode must succeed");

assert_eq!(nd, d)
assert_eq!(nd, d);

// A leftover byte means the decoder under-consumed its input.
reader
    .read_u8()
    .await
    .expect_err("decode must consume entire buffer");
}

#[tokio::test]
@@ -1025,6 +1038,19 @@
.await;
}

#[tokio::test]
async fn empty_string_with_varlen_bigvarchar() {
    test_round_trip(
        TypeInfo::VarLenSized(VarLenContext::new(
            VarLenType::BigVarChar,
            0x8ffff,
            Some(Collation::new(13632521, 52)),
        )),
        ColumnData::String(Some("".into())),
    )
    .await;
}

#[tokio::test]
async fn string_with_varlen_nvarchar() {
test_round_trip(
@@ -1051,6 +1077,19 @@
.await;
}

#[tokio::test]
async fn empty_string_with_varlen_nvarchar() {
    test_round_trip(
        TypeInfo::VarLenSized(VarLenContext::new(
            VarLenType::NVarchar,
            0x8ffff,
            Some(Collation::new(13632521, 52)),
        )),
        ColumnData::String(Some("".into())),
    )
    .await;
}

#[tokio::test]
async fn string_with_varlen_nchar() {
test_round_trip(
@@ -1157,6 +1196,19 @@
.await;
}

#[tokio::test]
async fn empty_binary_with_varlen_bigvarbin() {
    test_round_trip(
        TypeInfo::VarLenSized(VarLenContext::new(
            VarLenType::BigVarBin,
            0x8ffff,
            Some(Collation::new(13632521, 52)),
        )),
        ColumnData::Binary(Some(b"".as_slice().into())),
    )
    .await;
}

#[tokio::test]
async fn datetime_with_varlen_datetimen() {
test_round_trip(
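For context on the encoder changes above: in the TDS partially length-prefixed (PLP) encoding, each chunk is prefixed with a `u32` length and a zero-length chunk terminates the stream. As the change reads, an empty value's single `0u32` length already acts as the terminator, so also writing the explicit terminator would emit a stray trailing chunk marker; that is what the new empty-value round-trip tests pin down. A byte-level sketch of the two cases, assuming the unknown-length PLP header used here:

```rust
// Unknown-length PLP stream (header 0xFFFFFFFFFFFFFFFE), little-endian:
//
// Non-empty payload b"ab":
//   FE FF FF FF FF FF FF FF   // "unknown total length" marker
//   02 00 00 00               // chunk length = 2
//   61 62                     // chunk payload
//   00 00 00 00               // zero-length chunk = PLP terminator
//
// Empty payload b"":
//   FE FF FF FF FF FF FF FF   // "unknown total length" marker
//   00 00 00 00               // zero-length chunk doubles as the terminator
```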
2 changes: 1 addition & 1 deletion src/tds/codec/token/token_col_metadata.rs
@@ -25,7 +25,7 @@ pub struct MetaDataColumn<'a> {

impl<'a> Display for MetaDataColumn<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} ", self.col_name)?;
        write!(f, "[{}] ", self.col_name)?;

        match &self.base.ty {
            TypeInfo::FixedLen(fixed) => match fixed {
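The bracketed form matters when a column name is a reserved word or contains spaces, since this `Display` output is interpolated straight into the `INSERT BULK` column list. Together with the `WITH (...)` clause built in `client.rs`, the generated statement takes roughly this shape (hypothetical table and types, not captured output):

```rust
// For options KEEP_NULLS | TABLOCK and an ORDER hint on `id`,
// against a table (id INT, name NVARCHAR(50)):
//
//   INSERT BULK dbo.items ([id] int, [name] nvarchar(50))
//       WITH (KEEP_NULLS, TABLOCK, ORDER (id ASC))
```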