Skip to content

Commit

Permalink
Move float/struct casting to backend
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Jun 8, 2024
1 parent b2a6b53 commit c5c2eb7
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 34 deletions.
1 change: 0 additions & 1 deletion lib/explorer/polars_backend/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,6 @@ defmodule Explorer.PolarsBackend.DataFrame do
# Like `Explorer.Series.from_list/2`, but gives a better error message with the series name.
defp series_from_list!(name, list, dtype) do
type = Explorer.Shared.dtype_from_list!(list, dtype)
list = Explorer.Shared.cast_series(list, type)
series = Shared.from_list(list, type, name)
Explorer.Backend.Series.new(series, type)
rescue
Expand Down
2 changes: 1 addition & 1 deletion lib/explorer/polars_backend/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ defmodule Explorer.PolarsBackend.Shared do

row, columns ->
Enum.reduce(row, columns, fn {field, value}, columns ->
Map.update!(columns, field, &[value | &1])
Map.update!(columns, to_string(field), &[value | &1])
end)
end)

Expand Down
1 change: 0 additions & 1 deletion lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,6 @@ defmodule Explorer.Series do
normalised_dtype = if opts[:dtype], do: Shared.normalise_dtype!(opts[:dtype])

type = Shared.dtype_from_list!(list, normalised_dtype)
list = Shared.cast_series(list, type)

series = backend.from_list(list, type)

Expand Down
31 changes: 0 additions & 31 deletions lib/explorer/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -389,37 +389,6 @@ defmodule Explorer.Shared do
def leaf_dtype({:list, inner_dtype}), do: leaf_dtype(inner_dtype)
def leaf_dtype(dtype), do: dtype

@doc """
Downcasts lists of mixed numeric types (float and int) to float.
"""
def cast_series(list, {:struct, dtypes}) when is_list(list) do
Enum.map(list, fn
nil ->
nil

item ->
Enum.map(item, fn {field, inner_value} ->
column = to_string(field)
{^column, inner_dtype} = List.keyfind!(dtypes, column, 0)
[casted_value] = cast_series([inner_value], inner_dtype)
{column, casted_value}
end)
end)
end

def cast_series(list, {:list, inner_dtype}) when is_list(list) do
Enum.map(list, fn item -> cast_series(item, inner_dtype) end)
end

def cast_series(list, {:f, _}) do
Enum.map(list, fn
item when item in [nil, :infinity, :neg_infinity, :nan] or is_float(item) -> item
item -> item / 1
end)
end

def cast_series(list, _), do: list

@doc """
Merge two dtypes.
"""
Expand Down
4 changes: 4 additions & 0 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ macro_rules! from_list_float {
.decode::<ListIterator>()?
.map(|item| match item.get_type() {
TermType::Float => item.decode::<Option<$type>>(),
TermType::Integer => {
let int_value = item.decode::<i64>().unwrap();
Ok(Some(int_value as $type))
}
TermType::Atom => Ok(if nan.eq(&item) {
Some($module::NAN)
} else if infinity.eq(&item) {
Expand Down

0 comments on commit c5c2eb7

Please sign in to comment.