diff --git a/crates/polars-io/src/csv/read/options.rs b/crates/polars-io/src/csv/read/options.rs
index 8e668a1da06a..a1ad7ce0ebd5 100644
--- a/crates/polars-io/src/csv/read/options.rs
+++ b/crates/polars-io/src/csv/read/options.rs
@@ -37,6 +37,7 @@ pub struct CsvReadOptions {
     pub raise_if_empty: bool,
     pub ignore_errors: bool,
     pub fields_to_cast: Vec<Field>,
+    pub include_file_paths: Option<PlSmallStr>,
 }
 
 #[derive(Clone, Debug, PartialEq, Eq, Hash)]
@@ -81,6 +82,7 @@ impl Default for CsvReadOptions {
             raise_if_empty: true,
             ignore_errors: false,
             fields_to_cast: vec![],
+            include_file_paths: None,
         }
     }
 }
@@ -222,6 +224,12 @@ impl CsvReadOptions {
         self
     }
 
+    /// Include the path of the source file(s) as a column with this name; `None` omits the column.
+    pub fn with_include_file_paths(mut self, include_file_paths: Option<PlSmallStr>) -> Self {
+        self.include_file_paths = include_file_paths;
+        self
+    }
+
     /// Continue with next batch when a ParserError is encountered.
     pub fn with_ignore_errors(mut self, ignore_errors: bool) -> Self {
         self.ignore_errors = ignore_errors;
diff --git a/crates/polars-io/src/csv/read/reader.rs b/crates/polars-io/src/csv/read/reader.rs
index 098f27d82c0d..ef68252a35e3 100644
--- a/crates/polars-io/src/csv/read/reader.rs
+++ b/crates/polars-io/src/csv/read/reader.rs
@@ -279,6 +279,33 @@ where
         let mut csv_reader = self.core_reader()?;
         let mut df = csv_reader.as_df()?;
 
+        if let Some(col) = &self.options.include_file_paths {
+            // TODO: fix this - handle "open-file" vs "in-mem" - see `to_include_path_name`
+            let name = self
+                .options
+                .path
+                .as_ref()
+                .and_then(|path| path.to_str())
+                .unwrap_or("not a file");
+
+            if df.get_column_index(col).is_some() {
+                polars_bail!(
+                    Duplicate: r#"column name for file paths "{}" conflicts with column name from file"#,
+                    col
+                );
+            }
+
+            // SAFETY: the new column's length equals `df.height()`, and the
+            // duplicate-name check above guarantees the name is unique.
+            unsafe {
+                df.with_column_unchecked(Column::new_scalar(
+                    col.clone(),
+                    Scalar::new(DataType::String, AnyValue::StringOwned(name.into())),
+                    df.height(),
+                ));
+            }
+        }
+
         // Important that this rechunk is never done in parallel.
         // As that leads to great memory overhead.
         if rechunk && df.n_chunks() > 1 {
diff --git a/crates/polars-python/src/dataframe/io.rs b/crates/polars-python/src/dataframe/io.rs
index 9b34eb7e8ae9..aee0eb911b93 100644
--- a/crates/polars-python/src/dataframe/io.rs
+++ b/crates/polars-python/src/dataframe/io.rs
@@ -32,7 +32,7 @@ impl PyDataFrame {
         skip_rows, projection, separator, rechunk, columns, encoding, n_threads, path,
         overwrite_dtype, overwrite_dtype_slice, low_memory, comment_prefix, quote_char,
         null_values, missing_utf8_is_empty_string, try_parse_dates, skip_rows_after_header,
-        row_index, eol_char, raise_if_empty, truncate_ragged_lines, decimal_comma, schema)
+        row_index, eol_char, raise_if_empty, truncate_ragged_lines, decimal_comma, schema, include_file_paths)
     )]
     pub fn read_csv(
         py: Python,
@@ -65,6 +65,7 @@
         truncate_ragged_lines: bool,
         decimal_comma: bool,
         schema: Option<Wrap<Schema>>,
+        include_file_paths: Option<String>,
     ) -> PyResult<Self> {
         let null_values = null_values.map(|w| w.0);
         let eol_char = eol_char.as_bytes()[0];
@@ -113,6 +114,7 @@
             .with_skip_rows_after_header(skip_rows_after_header)
             .with_row_index(row_index)
             .with_raise_if_empty(raise_if_empty)
+            .with_include_file_paths(include_file_paths.map(|x| x.into()))
             .with_parse_options(
                 CsvParseOptions::default()
                     .with_separator(separator.as_bytes()[0])
diff --git a/py-polars/polars/io/csv/functions.py b/py-polars/polars/io/csv/functions.py
index 7ded05836f90..afd40b0c43fc 100644
--- a/py-polars/polars/io/csv/functions.py
+++ b/py-polars/polars/io/csv/functions.py
@@ -77,6 +77,7 @@ def read_csv(
     truncate_ragged_lines: bool = False,
    decimal_comma: bool = False,
    glob: bool = True,
+    include_file_paths: str | None = None,
 ) -> DataFrame:
     r"""
     Read a CSV file into a DataFrame.
@@ -208,6 +209,8 @@
         Parse floats using a comma as the decimal separator instead of a period.
     glob
         Expand path given via globbing rules.
+    include_file_paths
+        Include the path of the source file(s) as a column with this name.
 
     Returns
     -------
@@ -486,6 +489,7 @@ def read_csv(
             truncate_ragged_lines=truncate_ragged_lines,
             decimal_comma=decimal_comma,
             glob=glob,
+            include_file_paths=include_file_paths,
         )
 
         if columns:
@@ -532,6 +536,7 @@ def read_csv(
             truncate_ragged_lines=truncate_ragged_lines,
             decimal_comma=decimal_comma,
             glob=glob,
+            include_file_paths=include_file_paths,
         )
 
         if new_columns:
@@ -570,6 +575,7 @@ def _read_csv_impl(
     truncate_ragged_lines: bool = False,
     decimal_comma: bool = False,
     glob: bool = True,
+    include_file_paths: str | None = None,
 ) -> DataFrame:
     path: str | None
     if isinstance(source, (str, Path)):
@@ -634,6 +640,7 @@ def _read_csv_impl(
             truncate_ragged_lines=truncate_ragged_lines,
             decimal_comma=decimal_comma,
             glob=glob,
+            include_file_paths=include_file_paths,
         )
         if columns is None:
             return scan.collect()
@@ -678,6 +685,7 @@ def _read_csv_impl(
         truncate_ragged_lines=truncate_ragged_lines,
         decimal_comma=decimal_comma,
         schema=schema,
+        include_file_paths=include_file_paths,
     )
     return wrap_df(pydf)
 
diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py
index 640e5418555c..404da3fb0c1f 100644
--- a/py-polars/tests/unit/io/test_csv.py
+++ b/py-polars/tests/unit/io/test_csv.py
@@ -7,6 +7,7 @@
 import zlib
 from datetime import date, datetime, time, timedelta, timezone
 from decimal import Decimal as D
+from pathlib import Path
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, TypedDict
 
@@ -22,8 +23,6 @@
 from polars.testing import assert_frame_equal, assert_series_equal
 
 if TYPE_CHECKING:
-    from pathlib import Path
-
     from polars._typing import TimeUnit
     from tests.unit.conftest import MemoryUsage
 
@@ -2299,3 +2298,34 @@ def test_read_csv_cast_unparsable_later(
     df.write_csv(f)
     f.seek(0)
     assert df.equals(pl.read_csv(f, schema={"x": dtype}))
+
+
+@pytest.mark.write_disk
+@pytest.mark.parametrize(("number_of_files"), [1, 2])
+def test_read_csv_include_file_name(tmp_path: Path, number_of_files: int) -> None:
+    tmp_path.mkdir(exist_ok=True)
+    dfs: list[pl.DataFrame] = []
+
+    for x in ["1", "2"][:number_of_files]:
+        path = Path(f"{tmp_path}/{x}.csv").absolute()
+        dfs.append(pl.DataFrame({"x": 10 * [x]}).with_columns(path=pl.lit(str(path))))
+        dfs[-1].drop("path").write_csv(path)
+
+    expected = pl.concat(dfs)
+    assert expected.columns == ["x", "path"]
+
+    if number_of_files == 1:
+        read_csv_path = f"{tmp_path}/1.csv"
+    else:
+        read_csv_path = f"{tmp_path}/*.csv"
+
+    with pytest.raises(
+        pl.exceptions.DuplicateError,
+        match=r'column name for file paths "x" conflicts with column name from file',
+    ):
+        pl.read_csv(read_csv_path, include_file_paths="x")
+
+    res = pl.read_csv(
+        read_csv_path, include_file_paths="path", schema=expected.drop("path").schema
+    )
+    assert_frame_equal(res, expected)
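
Taken together, the patch threads `include_file_paths` from the Python keyword argument through `_read_csv_impl` and the PyO3 binding into `CsvReadOptions`, where the reader appends a String column holding each row's source path. A minimal usage sketch of the resulting behavior (the file names and data below are illustrative, not taken from the PR):

```python
import polars as pl

# Write two small CSV files, then read them back through a glob pattern.
pl.DataFrame({"x": [1, 2]}).write_csv("a.csv")
pl.DataFrame({"x": [3, 4]}).write_csv("b.csv")

# `include_file_paths="path"` adds a "path" column recording which file
# each row came from.
df = pl.read_csv("*.csv", include_file_paths="path")
assert df.columns == ["x", "path"]

# Requesting a name that already exists in the data raises DuplicateError,
# matching the check added in reader.rs:
# pl.read_csv("*.csv", include_file_paths="x")
```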