Skip to content

Commit

Permalink
Fix overwrite of dtypes for DF.load_csv/2
Browse files Browse the repository at this point in the history
Using the `with_schema` function does not correctly because expects to
receive all column names, and with the correct order. We are passing a
map to the backend, and then we transform that to a list before calling
the Rust code. So the schema could be out-of-order.

We are now using `with_schema_overwrite` for both `load_csv` and
`from_csv`.

Closes #953
  • Loading branch information
philss committed Aug 1, 2024
1 parent ac6bf30 commit 3fcf42f
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 12 deletions.
7 changes: 1 addition & 6 deletions lib/explorer/polars_backend/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,6 @@ defmodule Explorer.PolarsBackend.DataFrame do

{columns, with_projection} = column_names_or_projection(columns)

dtypes_list =
if not Enum.empty?(dtypes) do
Map.to_list(dtypes)
end

df =
Native.df_load_csv(
contents,
Expand All @@ -212,7 +207,7 @@ defmodule Explorer.PolarsBackend.DataFrame do
delimiter,
true,
columns,
dtypes_list,
Map.to_list(dtypes),
encoding,
nil_values,
parse_dates,
Expand Down
18 changes: 12 additions & 6 deletions native/explorer/src/dataframe/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,13 @@ pub fn df_from_csv(
_ => CsvEncoding::Utf8,
};

let dataframe = CsvReadOptions::default()
let read_options = if dtypes.is_empty() {
CsvReadOptions::default()
} else {
CsvReadOptions::default().with_schema_overwrite(Some(schema_from_dtypes_pairs(dtypes)?))
};

let dataframe = read_options
.with_infer_schema_length(infer_schema_length)
.with_has_header(has_header)
.with_n_rows(stop_after_n_rows)
Expand All @@ -56,7 +62,6 @@ pub fn df_from_csv(
.with_projection(projection.map(Arc::new))
.with_rechunk(do_rechunk)
.with_columns(column_names.map(Arc::from))
.with_schema_overwrite(Some(schema_from_dtypes_pairs(dtypes)?))
.with_parse_options(
CsvParseOptions::default()
.with_encoding(encoding)
Expand Down Expand Up @@ -152,7 +157,7 @@ pub fn df_load_csv(
delimiter_as_byte: u8,
do_rechunk: bool,
column_names: Option<Vec<String>>,
dtypes: Option<Vec<(&str, ExSeriesDtype)>>,
dtypes: Vec<(&str, ExSeriesDtype)>,
encoding: &str,
null_vals: Vec<String>,
parse_dates: bool,
Expand All @@ -165,9 +170,10 @@ pub fn df_load_csv(

let cursor = Cursor::new(binary.as_slice());

let read_options = match dtypes {
Some(val) => CsvReadOptions::default().with_schema(Some(schema_from_dtypes_pairs(val)?)),
None => CsvReadOptions::default(),
let read_options = if dtypes.is_empty() {
CsvReadOptions::default()
} else {
CsvReadOptions::default().with_schema_overwrite(Some(schema_from_dtypes_pairs(dtypes)?))
};

let dataframe = read_options
Expand Down
249 changes: 249 additions & 0 deletions test/explorer/data_frame/csv_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,130 @@ defmodule Explorer.DataFrame.CSVTest do
assert city[13] == "Aberdeen, Aberdeen City, UK"
end

test "load_csv/2 dtypes - all as strings" do
csv =
"""
id,first_name,last_name,email,gender,ip_address,salary,latitude,longitude
1,Torey,Geraghty,[email protected],Male,119.110.38.172,14036.68,38.9187037,-76.9611991
2,Nevin,Mandrake,[email protected],Male,161.2.124.233,32530.27,41.4176872,-8.7653155
3,Melisenda,Guiso,[email protected],Female,192.152.64.134,9177.8,21.3772424,110.2485736
4,Noble,Doggett,[email protected],Male,252.234.29.244,20328.76,37.268428,55.1487513
5,Janaya,Claypoole,[email protected],Female,150.191.214.252,21442.93,15.3553417,120.5293228
6,Sarah,Hugk,[email protected],Female,211.158.246.13,79709.16,28.168408,120.482198
7,Ulberto,Simenon,[email protected],Male,206.56.108.90,16248.98,48.4046776,-0.9746208
8,Kevon,Lingner,[email protected],Male,181.71.212.116,7497.64,-23.351784,-47.6931718
9,Sada,Garbert,[email protected],Female,170.42.190.231,15969.95,30.3414125,114.1543243
10,Salmon,Shoulders,[email protected],Male,68.138.106.143,19996.71,49.2152833,17.7687416
"""

headers = ~w(id first_name last_name email gender ip_address salary latitude longitude)

# Out of order on purpose.
df = DF.load_csv!(csv, dtypes: for(l <- Enum.shuffle(headers), do: {l, :string}))

assert DF.names(df) == headers

assert DF.to_columns(df, atom_keys: true) == %{
email: [
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]"
],
first_name: [
"Torey",
"Nevin",
"Melisenda",
"Noble",
"Janaya",
"Sarah",
"Ulberto",
"Kevon",
"Sada",
"Salmon"
],
gender: [
"Male",
"Male",
"Female",
"Male",
"Female",
"Female",
"Male",
"Male",
"Female",
"Male"
],
id: ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
ip_address: [
"119.110.38.172",
"161.2.124.233",
"192.152.64.134",
"252.234.29.244",
"150.191.214.252",
"211.158.246.13",
"206.56.108.90",
"181.71.212.116",
"170.42.190.231",
"68.138.106.143"
],
last_name: [
"Geraghty",
"Mandrake",
"Guiso",
"Doggett",
"Claypoole",
"Hugk",
"Simenon",
"Lingner",
"Garbert",
"Shoulders"
],
latitude: [
"38.9187037",
"41.4176872",
"21.3772424",
"37.268428",
"15.3553417",
"28.168408",
"48.4046776",
"-23.351784",
"30.3414125",
"49.2152833"
],
longitude: [
"-76.9611991",
"-8.7653155",
"110.2485736",
"55.1487513",
"120.5293228",
"120.482198",
"-0.9746208",
"-47.6931718",
"114.1543243",
"17.7687416"
],
salary: [
"14036.68",
"32530.27",
"9177.8",
"20328.76",
"21442.93",
"79709.16",
"16248.98",
"7497.64",
"15969.95",
"19996.71"
]
}
end

def assert_csv(type, csv_value, parsed_value, from_csv_options) do
data = "column\n#{csv_value}\n"
# parsing should work as expected
Expand Down Expand Up @@ -182,6 +306,131 @@ defmodule Explorer.DataFrame.CSVTest do
}
end

@tag :tmp_dir
test "dtypes - all as strings", config do
csv =
tmp_csv(config.tmp_dir, """
id,first_name,last_name,email,gender,ip_address,salary,latitude,longitude
1,Torey,Geraghty,[email protected],Male,119.110.38.172,14036.68,38.9187037,-76.9611991
2,Nevin,Mandrake,[email protected],Male,161.2.124.233,32530.27,41.4176872,-8.7653155
3,Melisenda,Guiso,[email protected],Female,192.152.64.134,9177.8,21.3772424,110.2485736
4,Noble,Doggett,[email protected],Male,252.234.29.244,20328.76,37.268428,55.1487513
5,Janaya,Claypoole,[email protected],Female,150.191.214.252,21442.93,15.3553417,120.5293228
6,Sarah,Hugk,[email protected],Female,211.158.246.13,79709.16,28.168408,120.482198
7,Ulberto,Simenon,[email protected],Male,206.56.108.90,16248.98,48.4046776,-0.9746208
8,Kevon,Lingner,[email protected],Male,181.71.212.116,7497.64,-23.351784,-47.6931718
9,Sada,Garbert,[email protected],Female,170.42.190.231,15969.95,30.3414125,114.1543243
10,Salmon,Shoulders,[email protected],Male,68.138.106.143,19996.71,49.2152833,17.7687416
""")

headers = ~w(id first_name last_name email gender ip_address salary latitude longitude)

# Out of order on purpose.
df = DF.from_csv!(csv, dtypes: for(l <- Enum.shuffle(headers), do: {l, :string}))

assert DF.names(df) == headers

assert DF.to_columns(df, atom_keys: true) == %{
email: [
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]"
],
first_name: [
"Torey",
"Nevin",
"Melisenda",
"Noble",
"Janaya",
"Sarah",
"Ulberto",
"Kevon",
"Sada",
"Salmon"
],
gender: [
"Male",
"Male",
"Female",
"Male",
"Female",
"Female",
"Male",
"Male",
"Female",
"Male"
],
id: ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
ip_address: [
"119.110.38.172",
"161.2.124.233",
"192.152.64.134",
"252.234.29.244",
"150.191.214.252",
"211.158.246.13",
"206.56.108.90",
"181.71.212.116",
"170.42.190.231",
"68.138.106.143"
],
last_name: [
"Geraghty",
"Mandrake",
"Guiso",
"Doggett",
"Claypoole",
"Hugk",
"Simenon",
"Lingner",
"Garbert",
"Shoulders"
],
latitude: [
"38.9187037",
"41.4176872",
"21.3772424",
"37.268428",
"15.3553417",
"28.168408",
"48.4046776",
"-23.351784",
"30.3414125",
"49.2152833"
],
longitude: [
"-76.9611991",
"-8.7653155",
"110.2485736",
"55.1487513",
"120.5293228",
"120.482198",
"-0.9746208",
"-47.6931718",
"114.1543243",
"17.7687416"
],
salary: [
"14036.68",
"32530.27",
"9177.8",
"20328.76",
"21442.93",
"79709.16",
"16248.98",
"7497.64",
"15969.95",
"19996.71"
]
}
end

@tag :tmp_dir
test "dtypes - parse datetime", config do
csv =
Expand Down

0 comments on commit 3fcf42f

Please sign in to comment.