fix: parquet, add row_count to empty file materialization (pola-rs#12310)
ritchie46 authored Nov 8, 2023
1 parent fa2b82a commit e0e6158
Showing 3 changed files with 34 additions and 7 deletions.
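The user-visible effect of this fix, adapted from the test added below (the file path is illustrative): scanning an empty Parquet file with a row count now materializes the row-count column, which was previously missing from the empty result.

import polars as pl

pl.DataFrame({"a": []}).write_parquet("empty.parquet")  # zero-row file
schema = pl.scan_parquet("empty.parquet").with_row_count("idx").collect().schema
# With the fix, the empty result keeps the row-count column:
# OrderedDict([("idx", pl.UInt32), ("a", pl.Float32)])
print(schema)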
15 changes: 10 additions & 5 deletions crates/polars-io/src/parquet/read.rs
@@ -18,7 +18,7 @@ use crate::parquet::async_impl::FetchRowGroupsFromObjectStore;
#[cfg(feature = "cloud")]
use crate::parquet::async_impl::ParquetObjectStore;
pub use crate::parquet::read_impl::BatchedParquetReader;
use crate::parquet::read_impl::{materialize_hive_partitions, read_parquet};
use crate::parquet::read_impl::{materialize_empty_df, read_parquet};
use crate::predicates::PhysicalIoExpr;
use crate::prelude::*;
use crate::RowCount;
@@ -365,8 +365,10 @@ impl ParquetAsyncReader {
pub async fn finish(mut self) -> PolarsResult<DataFrame> {
let rechunk = self.rechunk;
let metadata = self.get_metadata().await?.clone();
let schema = self.schema().await?;
let reader_schema = self.schema().await?;
let row_count = self.row_count.clone();
let hive_partition_columns = self.hive_partition_columns.clone();
let projection = self.projection.clone();

// batched reader deals with slice pushdown
let reader = self.batched(usize::MAX).await?;
@@ -378,9 +380,12 @@ impl ParquetAsyncReader {
chunks.push(result?)
}
if chunks.is_empty() {
let mut df = DataFrame::from(schema.as_ref());
materialize_hive_partitions(&mut df, hive_partition_columns.as_deref(), 0);
return Ok(df);
return Ok(materialize_empty_df(
projection.as_deref(),
reader_schema.as_ref(),
hive_partition_columns.as_deref(),
row_count.as_ref(),
));
}
let mut df = accumulate_dataframes_vertical_unchecked(chunks);

15 changes: 13 additions & 2 deletions crates/polars-io/src/parquet/read_impl.rs
@@ -328,18 +328,26 @@ fn rg_to_dfs_par_over_rg(
Ok(dfs.into_iter().flatten().collect())
}

fn materialize_empty_df(
pub(super) fn materialize_empty_df(
projection: Option<&[usize]>,
reader_schema: &ArrowSchema,
hive_partition_columns: Option<&[Series]>,
row_count: Option<&RowCount>,
) -> DataFrame {
let schema = if let Some(projection) = projection {
Cow::Owned(apply_projection(reader_schema, projection))
} else {
Cow::Borrowed(reader_schema)
};
let mut df = DataFrame::from(schema.as_ref());

if let Some(row_count) = row_count {
df.insert_at_idx(0, Series::new_empty(&row_count.name, &IDX_DTYPE))
.unwrap();
}

materialize_hive_partitions(&mut df, hive_partition_columns, 0);

df
}

@@ -362,6 +370,7 @@ pub fn read_parquet<R: MmapBytesReader>(
projection,
reader_schema,
hive_partition_columns,
row_count.as_ref(),
));
}

@@ -416,7 +425,7 @@ pub fn read_parquet<R: MmapBytesReader>(
&file_metadata,
reader_schema,
predicate,
row_count,
row_count.clone(),
parallel,
&materialized_projection,
use_statistics,
@@ -428,6 +437,7 @@
projection,
reader_schema,
hive_partition_columns,
row_count.as_ref(),
))
} else {
accumulate_dataframes_vertical(dfs)
@@ -633,6 +643,7 @@ impl BatchedParquetReader {
Some(&self.projection),
self.schema.as_ref(),
self.hive_partition_columns.as_deref(),
self.row_count.as_ref(),
)]));
}
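In summary, materialize_empty_df builds an empty DataFrame from the (optionally projected) reader schema, now inserts an empty row-count column at index 0 when one was requested, and then appends any hive partition columns; the same helper is shared by read_parquet and BatchedParquetReader. A hedged sketch of the expected end-to-end behavior when a projection and a row count are combined on an empty file (the file and column names here are made up for illustration):

import polars as pl

pl.DataFrame({"a": [], "b": []}).write_parquet("empty_two_cols.parquet")
out = (
    pl.scan_parquet("empty_two_cols.parquet")
    .select("b")              # projection applied to the reader schema
    .with_row_count("idx")    # empty "idx" column inserted at index 0
    .collect()
)
# Expected: a 0-row frame with schema {"idx": UInt32, "b": Float32}
print(out.schema)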

11 changes: 11 additions & 0 deletions py-polars/tests/unit/io/test_lazy_parquet.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Any

@@ -425,3 +426,13 @@ def test_parquet_many_row_groups_12297(tmp_path: Path) -> None:
df = pl.DataFrame({"x": range(100)})
df.write_parquet(file_path, row_group_size=5, use_pyarrow=True)
assert_frame_equal(pl.scan_parquet(file_path).collect(), df)


@pytest.mark.write_disk()
def test_row_count_empty_file(tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)
file_path = tmp_path / "test.parquet"
pl.DataFrame({"a": []}).write_parquet(file_path)
assert pl.scan_parquet(file_path).with_row_count(
"idx"
).collect().schema == OrderedDict([("idx", pl.UInt32), ("a", pl.Float32)])
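For contrast, a quick check that a non-empty file yields the same row-count column with actual values, so the empty-file path is now consistent with the non-empty one (the path is illustrative):

import polars as pl

pl.DataFrame({"a": [1.0, 2.0, 3.0]}).write_parquet("non_empty.parquet")
out = pl.scan_parquet("non_empty.parquet").with_row_count("idx").collect()
assert out["idx"].to_list() == [0, 1, 2]  # a UInt32 (IDX_DTYPE) counter starting at 0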
