diff --git a/lib/explorer/polars_backend/data_frame.ex b/lib/explorer/polars_backend/data_frame.ex index d9278f0e0..6749de175 100644 --- a/lib/explorer/polars_backend/data_frame.ex +++ b/lib/explorer/polars_backend/data_frame.ex @@ -570,7 +570,6 @@ defmodule Explorer.PolarsBackend.DataFrame do # Like `Explorer.Series.from_list/2`, but gives a better error message with the series name. defp series_from_list!(name, list, dtype) do type = Explorer.Shared.dtype_from_list!(list, dtype) - list = Explorer.Shared.cast_series(list, type) series = Shared.from_list(list, type, name) Explorer.Backend.Series.new(series, type) rescue diff --git a/lib/explorer/polars_backend/shared.ex b/lib/explorer/polars_backend/shared.ex index b5f4ac0c2..037d80d10 100644 --- a/lib/explorer/polars_backend/shared.ex +++ b/lib/explorer/polars_backend/shared.ex @@ -152,7 +152,7 @@ defmodule Explorer.PolarsBackend.Shared do row, columns -> Enum.reduce(row, columns, fn {field, value}, columns -> - Map.update!(columns, field, &[value | &1]) + Map.update!(columns, to_string(field), &[value | &1]) end) end) diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index 1bf611cd8..259f37f29 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -461,7 +461,6 @@ defmodule Explorer.Series do normalised_dtype = if opts[:dtype], do: Shared.normalise_dtype!(opts[:dtype]) type = Shared.dtype_from_list!(list, normalised_dtype) - list = Shared.cast_series(list, type) series = backend.from_list(list, type) diff --git a/lib/explorer/shared.ex b/lib/explorer/shared.ex index d179faf88..556beef91 100644 --- a/lib/explorer/shared.ex +++ b/lib/explorer/shared.ex @@ -389,37 +389,6 @@ defmodule Explorer.Shared do def leaf_dtype({:list, inner_dtype}), do: leaf_dtype(inner_dtype) def leaf_dtype(dtype), do: dtype - @doc """ - Downcasts lists of mixed numeric types (float and int) to float. - """ - def cast_series(list, {:struct, dtypes}) when is_list(list) do - Enum.map(list, fn - nil -> - nil - - item -> - Enum.map(item, fn {field, inner_value} -> - column = to_string(field) - {^column, inner_dtype} = List.keyfind!(dtypes, column, 0) - [casted_value] = cast_series([inner_value], inner_dtype) - {column, casted_value} - end) - end) - end - - def cast_series(list, {:list, inner_dtype}) when is_list(list) do - Enum.map(list, fn item -> cast_series(item, inner_dtype) end) - end - - def cast_series(list, {:f, _}) do - Enum.map(list, fn - item when item in [nil, :infinity, :neg_infinity, :nan] or is_float(item) -> item - item -> item / 1 - end) - end - - def cast_series(list, _), do: list - @doc """ Merge two dtypes. """ diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index cad789b3f..e74ed565a 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -63,6 +63,10 @@ macro_rules! from_list_float { .decode::()? .map(|item| match item.get_type() { TermType::Float => item.decode::>(), + TermType::Integer => { + let int_value = item.decode::().unwrap(); + Ok(Some(int_value as $type)) + } TermType::Atom => Ok(if nan.eq(&item) { Some($module::NAN) } else if infinity.eq(&item) {