Skip to content

Commit

Permalink
feat(python,rust): add encoding parameter to write_csv
Browse files Browse the repository at this point in the history
  • Loading branch information
borchero committed Aug 7, 2023
1 parent 82c1685 commit b0ab82b
Show file tree
Hide file tree
Showing 10 changed files with 194 additions and 8 deletions.
2 changes: 2 additions & 0 deletions crates/polars-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ async-trait = { version = "0.1.59", optional = true }
bytes = "1.3.0"
chrono = { version = "0.4", default-features = false, features = ["std"], optional = true }
chrono-tz = { version = "0.8.1", optional = true }
encoding_rs = { version = "0.8.32", features = ["simd-accel"], optional = true }
fast-float = { version = "0.2.0", optional = true }
flate2 = { version = "1", optional = true, default-features = false }
futures = { version = "0.3.25", optional = true }
Expand Down Expand Up @@ -64,6 +65,7 @@ ipc_streaming = ["arrow/io_ipc", "arrow/io_ipc_compression"]
# support for arrow avro parsing
avro = ["arrow/io_avro", "arrow/io_avro_compression"]
csv = ["memmap", "lexical", "polars-core/rows", "lexical-core", "fast-float", "simdutf8"]
csv-encoding = ["encoding_rs"]
decompress = ["flate2/miniz_oxide"]
decompress-fast = ["flate2/zlib-ng"]
dtype-categorical = ["polars-core/dtype-categorical"]
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-io/src/csv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ pub mod read_impl;

mod read;
pub(super) mod splitfields;
#[cfg(feature = "csv-encoding")]
mod transcoding;
pub mod utils;
mod write;
pub(super) mod write_impl;
Expand Down
64 changes: 64 additions & 0 deletions crates/polars-io/src/csv/transcoding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use std::io::Write;

use encoding_rs::{Encoder, EncoderResult, Encoding};

/// Adapter that re-encodes the text written to it before forwarding the bytes to `sink`.
///
/// The bytes handed to `write` are treated as UTF-8 and converted with `encoder`.
pub(super) struct TranscodingWriter<'a, W> {
// Destination that receives the encoded bytes.
sink: &'a mut W,
// Stateful `encoding_rs` encoder for the target encoding.
encoder: Encoder,
// Fixed-size scratch space the encoder writes its output into before it is
// forwarded to `sink`.
buffer: [u8; 1024],
}

impl<'a, W> TranscodingWriter<'a, W>
where
    W: Write,
{
    /// Creates a writer that encodes incoming UTF-8 text with `encoding` and
    /// forwards the encoded bytes to `sink`.
    ///
    /// The encoder is obtained from `encoding.new_encoder()` and the internal
    /// scratch buffer starts out zeroed.
    pub(super) fn new(sink: &'a mut W, encoding: &'static Encoding) -> Self {
        Self {
            sink,
            encoder: encoding.new_encoder(),
            buffer: [0; 1024],
        }
    }
}

impl<'a, W> Write for TranscodingWriter<'a, W>
where
    W: Write,
{
    /// Encodes a prefix of `buf` (which must be UTF-8) and forwards the encoded
    /// bytes to the sink.
    ///
    /// Returns the number of *input* bytes consumed, which may be smaller than
    /// `buf.len()`; callers are expected to use `write_all` to push everything.
    ///
    /// # Errors
    ///
    /// * `InvalidData` if `buf` does not start with valid UTF-8.
    /// * `Other` if a character cannot be represented in the target encoding.
    /// * Any error reported by the sink.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        // Only inspect a bounded window per call. `write_all` calls us again with
        // the unconsumed remainder, so validating the whole slice every time
        // would re-scan the tail of large inputs over and over.
        let window = &buf[..buf.len().min(self.buffer.len())];
        let src = match std::str::from_utf8(window) {
            Ok(valid) => valid,
            // The window may end in the middle of a multi-byte character (or hit
            // an invalid byte later on): consume the valid prefix for now.
            Err(err) if err.valid_up_to() > 0 => {
                // SAFETY: `from_utf8` has just verified that the first
                // `valid_up_to()` bytes of `window` are valid UTF-8.
                unsafe { std::str::from_utf8_unchecked(&window[..err.valid_up_to()]) }
            }
            Err(err) => {
                return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, err));
            }
        };

        let (result, n_bytes_read, n_bytes_written) = self
            .encoder
            .encode_from_utf8_without_replacement(src, &mut self.buffer, false);
        match result {
            EncoderResult::InputEmpty | EncoderResult::OutputFull => {
                self.sink.write_all(&self.buffer[..n_bytes_written])?;
                Ok(n_bytes_read)
            }
            EncoderResult::Unmappable(c) => Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("failed to encode character '{}'", c),
            )),
        }
    }

    /// Emits any bytes the encoder still needs to output to finish the stream
    /// and then flushes the sink.
    fn flush(&mut self) -> std::io::Result<()> {
        if self.encoder.has_pending_state() {
            let (result, _, n_bytes_written) =
                self.encoder
                    .encode_from_utf8_without_replacement("", &mut self.buffer, true);
            match result {
                EncoderResult::InputEmpty => {}
                // Empty input cannot contain an unmappable character, but treat
                // every non-success result as a failure instead of ignoring it.
                EncoderResult::OutputFull | EncoderResult::Unmappable(_) => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        "failed to finalize encoding",
                    ));
                }
            }
            // Forward only what the encoder produced; the rest of the scratch
            // buffer holds stale bytes from earlier writes.
            self.sink.write_all(&self.buffer[..n_bytes_written])?;
        }
        self.sink.flush()
    }
}
71 changes: 67 additions & 4 deletions crates/polars-io/src/csv/write.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#[cfg(feature = "csv-encoding")]
use encoding_rs::{Encoding, UTF_8};

use super::*;

/// Write a DataFrame to csv.
Expand All @@ -10,6 +13,8 @@ pub struct CsvWriter<W: Write> {
options: write_impl::SerializeOptions,
header: bool,
batch_size: usize,
#[cfg(feature = "csv-encoding")]
encoding: Option<&'static Encoding>,
}

impl<W> SerWriter<W> for CsvWriter<W>
Expand All @@ -28,22 +33,52 @@ where
options,
header: true,
batch_size: 1024,
#[cfg(feature = "csv-encoding")]
encoding: None,
}
}

fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> {
let names = df.get_column_names();
if self.header {
write_impl::write_header(&mut self.buffer, &names, &self.options)?;
match self.encoding {
Some(enc) => {
let mut writer = transcoding::TranscodingWriter::new(&mut self.buffer, enc);
Self::finish_with_writer(
&mut writer,
df,
self.header,
self.batch_size,
&self.options,
)
}
None => Self::finish_with_writer(
&mut self.buffer,
df,
self.header,
self.batch_size,
&self.options,
),
}
write_impl::write(&mut self.buffer, df, self.batch_size, &self.options)
}
}

impl<W> CsvWriter<W>
where
W: Write,
{
/// Writes `df` as CSV to `writer`: the header row first when `header` is set,
/// followed by the data via `write_impl::write` using `batch_size` and `options`.
fn finish_with_writer(
    writer: &mut impl Write,
    df: &mut DataFrame,
    header: bool,
    batch_size: usize,
    options: &write_impl::SerializeOptions,
) -> PolarsResult<()> {
    let column_names = df.get_column_names();
    if !header {
        // No header requested: go straight to the data rows.
        return write_impl::write(writer, df, batch_size, options);
    }
    write_impl::write_header(writer, &column_names, options)?;
    write_impl::write(writer, df, batch_size, options)
}

/// Set whether to write headers
pub fn has_header(mut self, has_header: bool) -> Self {
self.header = has_header;
Expand All @@ -61,6 +96,34 @@ where
self
}

/// Set the CSV file's encoding.
///
/// `encoding` is an encoding label as understood by `Encoding::for_label`
/// (for example `"cp1252"`). Passing `None` keeps the output as UTF-8.
///
/// Encodings whose output encoding is UTF-8 are stored as `None` so that the
/// slower transcoding path is skipped; `encoding_rs` encoders always emit the
/// *output encoding* of the requested encoding, so this does not change the
/// bytes that are written.
///
/// # Errors
///
/// Returns an error if `encoding` is not a known encoding label.
#[cfg(feature = "csv-encoding")]
pub fn with_encoding(mut self, encoding: Option<String>) -> PolarsResult<Self> {
    self.encoding = encoding
        .map(|label| {
            // Build the error lazily: it is only needed for unknown labels.
            Encoding::for_label(label.as_bytes()).ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("unknown encoding '{}'", label),
                )
            })
        })
        .transpose()?
        // Transcoding is much slower than writing UTF-8 directly, so only keep
        // encodings that actually produce something other than UTF-8.
        .filter(|&enc| enc.output_encoding() != UTF_8);
    Ok(self)
}

/// Set the CSV file's date format
pub fn with_date_format(mut self, format: Option<String>) -> Self {
if format.is_some() {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/csv/write_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ pub(crate) fn write<W: Write>(

for buf in result_buf.drain(..) {
let mut buf = buf?;
let _ = writer.write(&buf)?;
let _ = writer.write_all(&buf)?;
buf.clear();
write_buffer_pool.set(buf);
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ ipc_streaming = ["polars-io", "polars-io/ipc_streaming", "polars-lazy/ipc"]
avro = ["polars-io", "polars-io/avro"]

# support for arrows csv file parsing
csv = ["polars-io", "polars-io/csv", "polars-lazy/csv", "polars-sql/csv"]
csv = ["polars-io", "polars-io/csv", "polars-io/csv-encoding", "polars-lazy/csv", "polars-sql/csv"]

# slower builds
performant = [
Expand Down
29 changes: 28 additions & 1 deletion py-polars/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2373,6 +2373,7 @@ def write_csv(
time_format: str | None = ...,
float_precision: int | None = ...,
null_value: str | None = ...,
encoding: str | None = ...,
) -> str:
...

Expand All @@ -2390,6 +2391,7 @@ def write_csv(
time_format: str | None = ...,
float_precision: int | None = ...,
null_value: str | None = ...,
encoding: str | None = ...,
) -> None:
...

Expand All @@ -2406,6 +2408,7 @@ def write_csv(
time_format: str | None = None,
float_precision: int | None = None,
null_value: str | None = None,
encoding: str | None = None,
) -> str | None:
"""
Write to comma-separated values (CSV) file.
Expand Down Expand Up @@ -2442,6 +2445,9 @@ def write_csv(
``Float64`` datatypes.
null_value
A string representing null values (defaulting to the empty string).
encoding
A string representing the encoding to use in the CSV file. Defaults to
'utf-8'.
Examples
--------
Expand Down Expand Up @@ -2478,8 +2484,9 @@ def write_csv(
time_format,
float_precision,
null_value,
encoding,
)
return str(buffer.getvalue(), encoding="utf-8")
return str(buffer.getvalue(), encoding=encoding or "utf-8")

if isinstance(file, (str, Path)):
file = normalise_filepath(file)
Expand All @@ -2497,6 +2504,7 @@ def write_csv(
time_format,
float_precision,
null_value,
encoding,
)
return None

Expand Down
5 changes: 5 additions & 0 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ impl PyDataFrame {
time_format: Option<String>,
float_precision: Option<usize>,
null_value: Option<String>,
encoding: Option<String>,
) -> PyResult<()> {
let null = null_value.unwrap_or_default();

Expand All @@ -570,6 +571,8 @@ impl PyDataFrame {
.with_time_format(time_format)
.with_float_precision(float_precision)
.with_null_value(null)
.with_encoding(encoding)
.map_err(PyPolarsErr::from)?
.finish(&mut self.df)
.map_err(PyPolarsErr::from)
})?;
Expand All @@ -585,6 +588,8 @@ impl PyDataFrame {
.with_time_format(time_format)
.with_float_precision(float_precision)
.with_null_value(null)
.with_encoding(encoding)
.map_err(PyPolarsErr::from)?
.finish(&mut self.df)
.map_err(PyPolarsErr::from)?;
}
Expand Down
15 changes: 15 additions & 0 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1377,3 +1377,18 @@ def test_csv_9929() -> None:
f.seek(0)
with pytest.raises(pl.NoDataError):
pl.read_csv(f, skip_rows=10**6)


def test_write_csv_encoding() -> None:
    """Writing with ``encoding="cp1252"`` yields cp1252 bytes for the same text."""
    df = pl.DataFrame({"a": ["test", "\u201etest\u201d"]})

    utf8_out = io.BytesIO()
    df.write_csv(utf8_out)
    utf8_buf = utf8_out.getvalue()

    cp1252_out = io.BytesIO()
    df.write_csv(cp1252_out, encoding="cp1252")
    cp1252_buf = cp1252_out.getvalue()

    # The typographic quotes are multi-byte in UTF-8 but single bytes in cp1252,
    # so the raw outputs must differ while decoding to the same text.
    assert utf8_buf != cp1252_buf
    assert b"\x84test\x94" in cp1252_buf
    assert utf8_buf.decode() == cp1252_buf.decode("cp1252")

0 comments on commit b0ab82b

Please sign in to comment.