diff --git a/native/explorer/src/encoding.rs b/native/explorer/src/encoding.rs index 755003bde..a63353c3c 100644 --- a/native/explorer/src/encoding.rs +++ b/native/explorer/src/encoding.rs @@ -725,7 +725,41 @@ pub fn resource_term_from_value<'b>( AnyValue::Binary(v) => unsafe { Ok(Some(resource.make_binary_unsafe(env, |_| v)).encode(env)) }, - _ => term_from_value(v, env), + AnyValue::Null => Ok(atom::nil().to_term(env)), + AnyValue::Boolean(v) => Ok(v.encode(env)), + AnyValue::String(v) => Ok(v.encode(env)), + AnyValue::Int8(v) => Ok(v.encode(env)), + AnyValue::Int16(v) => Ok(v.encode(env)), + AnyValue::Int32(v) => Ok(v.encode(env)), + AnyValue::Int64(v) => Ok(v.encode(env)), + AnyValue::UInt8(v) => Ok(v.encode(env)), + AnyValue::UInt16(v) => Ok(v.encode(env)), + AnyValue::UInt32(v) => Ok(v.encode(env)), + AnyValue::UInt64(v) => Ok(v.encode(env)), + AnyValue::Float32(v) => Ok(term_from_float32(v, env)), + AnyValue::Float64(v) => Ok(term_from_float64(v, env)), + AnyValue::Date(v) => encode_date(v, env), + AnyValue::Time(v) => encode_time(v, env), + AnyValue::Datetime(v, time_unit, None) => encode_naive_datetime(v, time_unit, env), + AnyValue::Datetime(v, time_unit, Some(time_zone)) => { + encode_datetime(v, time_unit, time_zone.to_string(), env) + } + AnyValue::Duration(v, time_unit) => encode_duration(v, time_unit, env), + AnyValue::Categorical(idx, mapping, _) => Ok(mapping.get(idx).encode(env)), + AnyValue::List(series) => list_from_series(ExSeries::new(series), env), + AnyValue::Struct(_, _, fields) => v + ._iter_struct_av() + .zip(fields) + .map(|(value, field)| { + Ok(( + field.name.as_str(), + resource_term_from_value(resource, value, env)?, + )) + }) + .collect::, ExplorerError>>() + .map(|map| map.encode(env)), + AnyValue::Decimal(number, scale) => encode_decimal(number, scale, env), + dt => panic!("cannot encode value {dt:?} to term"), } } @@ -749,41 +783,6 @@ macro_rules! term_from_float { term_from_float!(term_from_float64, f64); term_from_float!(term_from_float32, f32); -pub fn term_from_value<'b>(v: AnyValue, env: Env<'b>) -> Result, ExplorerError> { - match v { - AnyValue::Null => Ok(None::.encode(env)), - AnyValue::Boolean(v) => Ok(Some(v).encode(env)), - AnyValue::String(v) => Ok(Some(v).encode(env)), - AnyValue::Int8(v) => Ok(Some(v).encode(env)), - AnyValue::Int16(v) => Ok(Some(v).encode(env)), - AnyValue::Int32(v) => Ok(Some(v).encode(env)), - AnyValue::Int64(v) => Ok(Some(v).encode(env)), - AnyValue::UInt8(v) => Ok(Some(v).encode(env)), - AnyValue::UInt16(v) => Ok(Some(v).encode(env)), - AnyValue::UInt32(v) => Ok(Some(v).encode(env)), - AnyValue::UInt64(v) => Ok(Some(v).encode(env)), - AnyValue::Float32(v) => Ok(Some(term_from_float32(v, env)).encode(env)), - AnyValue::Float64(v) => Ok(Some(term_from_float64(v, env)).encode(env)), - AnyValue::Date(v) => encode_date(v, env), - AnyValue::Time(v) => encode_time(v, env), - AnyValue::Datetime(v, time_unit, None) => encode_naive_datetime(v, time_unit, env), - AnyValue::Datetime(v, time_unit, Some(time_zone)) => { - encode_datetime(v, time_unit, time_zone.to_string(), env) - } - AnyValue::Duration(v, time_unit) => encode_duration(v, time_unit, env), - AnyValue::Categorical(idx, mapping, _) => Ok(mapping.get(idx).encode(env)), - AnyValue::List(series) => list_from_series(ExSeries::new(series), env), - AnyValue::Struct(_, _, fields) => v - ._iter_struct_av() - .zip(fields) - .map(|(value, field)| Ok((field.name.as_str(), term_from_value(value, env)?))) - .collect::, ExplorerError>>() - .map(|map| map.encode(env)), - AnyValue::Decimal(number, scale) => encode_decimal(number, scale, env), - dt => panic!("cannot encode value {dt:?} to term"), - } -} - pub fn list_from_series(s: ExSeries, env: Env) -> Result { match s.dtype() { DataType::Null => null_series_to_list(&s, env), @@ -823,7 +822,7 @@ pub fn list_from_series(s: ExSeries, env: Env) -> Result { .map(|lists| lists.encode(env)), DataType::Struct(_fields) => s .iter() - .map(|value| term_from_value(value, env)) + .map(|value| resource_term_from_value(&s.resource, value, env)) .collect::, ExplorerError>>() .map(|values| values.encode(env)), DataType::Decimal(_precision, _scale) => decimal_series_to_list(&s, env), diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index ad4b9aa9b..5b66ebc11 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -1015,7 +1015,8 @@ pub fn s_quantile<'a>( .unwrap() .encode(env)), }, - _ => encoding::term_from_value( + _ => encoding::resource_term_from_value( + &s.resource, s.quantile_reduce(quantile, strategy)? .into_series("quantile".into()) .cast(dtype)? diff --git a/test/explorer/series/struct_test.exs b/test/explorer/series/struct_test.exs index d7dd77be1..c0c2927ad 100644 --- a/test/explorer/series/struct_test.exs +++ b/test/explorer/series/struct_test.exs @@ -74,7 +74,7 @@ defmodule Explorer.Series.StructTest do ] end - test "allows structs structs with special float values" do + test "allows structs with special float values" do series = Series.from_list([%{a: :nan, b: :infinity, c: :neg_infinity}]) assert series.dtype == {:struct, [{"a", {:f, 64}}, {"b", {:f, 64}}, {"c", {:f, 64}}]} @@ -126,6 +126,24 @@ defmodule Explorer.Series.StructTest do assert series1.dtype == {:struct, [{"a", :string}, {"b", :string}]} end + test "allow binaries inside structs" do + a1 = <<60, 15, 233, 144, 204, 179, 30, 148, 197, 140, 19, 96, 65, 131, 83, 195>> + a2 = <<176, 176, 30, 208, 242, 105, 50, 4, 26, 51, 170, 216, 121, 119, 214, 218>> + a3 = <<5, 77, 199, 172, 128, 195, 160, 241, 93, 61, 84, 116, 61, 141, 195, 130>> + + s = + Series.from_list([%{a: a1}, %{a: a2}, %{a: a3}, %{a: nil}], + dtype: {:struct, [{"a", :binary}]} + ) + + assert s.dtype == {:struct, [{"a", :binary}]} + + assert Series.to_list(s) == [%{"a" => a1}, %{"a" => a2}, %{"a" => a3}, %{"a" => nil}] + + assert s[1] === %{"a" => a2} + assert s[3] === %{"a" => nil} + end + test "errors when structs have mismatched types" do assert_raise ArgumentError, "the value \"a\" does not match the inferred dtype {:s, 64}",