Skip to content

Commit

Permalink
Fix encode of binaries to terms inside struct series (#996)
Browse files Browse the repository at this point in the history
Fixes #994
  • Loading branch information
philss authored Sep 26, 2024
1 parent ed18e53 commit 02c68a3
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 39 deletions.
73 changes: 36 additions & 37 deletions native/explorer/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,41 @@ pub fn resource_term_from_value<'b>(
AnyValue::Binary(v) => unsafe {
Ok(Some(resource.make_binary_unsafe(env, |_| v)).encode(env))
},
_ => term_from_value(v, env),
AnyValue::Null => Ok(atom::nil().to_term(env)),
AnyValue::Boolean(v) => Ok(v.encode(env)),
AnyValue::String(v) => Ok(v.encode(env)),
AnyValue::Int8(v) => Ok(v.encode(env)),
AnyValue::Int16(v) => Ok(v.encode(env)),
AnyValue::Int32(v) => Ok(v.encode(env)),
AnyValue::Int64(v) => Ok(v.encode(env)),
AnyValue::UInt8(v) => Ok(v.encode(env)),
AnyValue::UInt16(v) => Ok(v.encode(env)),
AnyValue::UInt32(v) => Ok(v.encode(env)),
AnyValue::UInt64(v) => Ok(v.encode(env)),
AnyValue::Float32(v) => Ok(term_from_float32(v, env)),
AnyValue::Float64(v) => Ok(term_from_float64(v, env)),
AnyValue::Date(v) => encode_date(v, env),
AnyValue::Time(v) => encode_time(v, env),
AnyValue::Datetime(v, time_unit, None) => encode_naive_datetime(v, time_unit, env),
AnyValue::Datetime(v, time_unit, Some(time_zone)) => {
encode_datetime(v, time_unit, time_zone.to_string(), env)
}
AnyValue::Duration(v, time_unit) => encode_duration(v, time_unit, env),
AnyValue::Categorical(idx, mapping, _) => Ok(mapping.get(idx).encode(env)),
AnyValue::List(series) => list_from_series(ExSeries::new(series), env),
AnyValue::Struct(_, _, fields) => v
._iter_struct_av()
.zip(fields)
.map(|(value, field)| {
Ok((
field.name.as_str(),
resource_term_from_value(resource, value, env)?,
))
})
.collect::<Result<HashMap<_, _>, ExplorerError>>()
.map(|map| map.encode(env)),
AnyValue::Decimal(number, scale) => encode_decimal(number, scale, env),
dt => panic!("cannot encode value {dt:?} to term"),
}
}

Expand All @@ -749,41 +783,6 @@ macro_rules! term_from_float {
term_from_float!(term_from_float64, f64);
term_from_float!(term_from_float32, f32);

/// Encodes a single `AnyValue` as an Erlang term.
///
/// Scalar variants are wrapped in `Some` before being encoded, and
/// `AnyValue::Null` is encoded as `None::<bool>`. Date, time, datetime,
/// duration, decimal and list values are handed to their dedicated encoders.
/// Struct values are encoded field by field (recursively through this same
/// function) into a map keyed by field name.
///
/// # Errors
///
/// Returns the error produced by a delegated encoder, or by the encoding of
/// a nested struct field.
///
/// # Panics
///
/// Panics on any `AnyValue` variant that has no arm below. Note that there
/// is no arm for `AnyValue::Binary`, so binaries reaching this function
/// (including ones nested inside structs) end up in the panicking fallback.
pub fn term_from_value<'b>(v: AnyValue, env: Env<'b>) -> Result<Term<'b>, ExplorerError> {
    let term = match v {
        // Values that can be encoded in place.
        AnyValue::Null => None::<bool>.encode(env),
        AnyValue::Boolean(flag) => Some(flag).encode(env),
        AnyValue::String(text) => Some(text).encode(env),
        AnyValue::Int8(n) => Some(n).encode(env),
        AnyValue::Int16(n) => Some(n).encode(env),
        AnyValue::Int32(n) => Some(n).encode(env),
        AnyValue::Int64(n) => Some(n).encode(env),
        AnyValue::UInt8(n) => Some(n).encode(env),
        AnyValue::UInt16(n) => Some(n).encode(env),
        AnyValue::UInt32(n) => Some(n).encode(env),
        AnyValue::UInt64(n) => Some(n).encode(env),
        AnyValue::Float32(x) => Some(term_from_float32(x, env)).encode(env),
        AnyValue::Float64(x) => Some(term_from_float64(x, env)).encode(env),
        AnyValue::Categorical(idx, mapping, _) => mapping.get(idx).encode(env),

        // Values whose encoders already return a `Result`: hand it back untouched.
        AnyValue::Date(date) => return encode_date(date, env),
        AnyValue::Time(time) => return encode_time(time, env),
        AnyValue::Datetime(stamp, time_unit, None) => {
            return encode_naive_datetime(stamp, time_unit, env)
        }
        AnyValue::Datetime(stamp, time_unit, Some(time_zone)) => {
            return encode_datetime(stamp, time_unit, time_zone.to_string(), env)
        }
        AnyValue::Duration(span, time_unit) => return encode_duration(span, time_unit, env),
        AnyValue::List(series) => return list_from_series(ExSeries::new(series), env),
        AnyValue::Decimal(number, scale) => return encode_decimal(number, scale, env),

        // Structs: pair every field value with its field name and encode the
        // values recursively; the first failing field aborts the whole call.
        AnyValue::Struct(_, _, fields) => {
            let mut entries = HashMap::new();
            for (value, field) in v._iter_struct_av().zip(fields) {
                entries.insert(field.name.as_str(), term_from_value(value, env)?);
            }
            entries.encode(env)
        }

        dt => panic!("cannot encode value {dt:?} to term"),
    };

    Ok(term)
}

pub fn list_from_series(s: ExSeries, env: Env) -> Result<Term, ExplorerError> {
match s.dtype() {
DataType::Null => null_series_to_list(&s, env),
Expand Down Expand Up @@ -823,7 +822,7 @@ pub fn list_from_series(s: ExSeries, env: Env) -> Result<Term, ExplorerError> {
.map(|lists| lists.encode(env)),
DataType::Struct(_fields) => s
.iter()
.map(|value| term_from_value(value, env))
.map(|value| resource_term_from_value(&s.resource, value, env))
.collect::<Result<Vec<_>, ExplorerError>>()
.map(|values| values.encode(env)),
DataType::Decimal(_precision, _scale) => decimal_series_to_list(&s, env),
Expand Down
3 changes: 2 additions & 1 deletion native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,8 @@ pub fn s_quantile<'a>(
.unwrap()
.encode(env)),
},
_ => encoding::term_from_value(
_ => encoding::resource_term_from_value(
&s.resource,
s.quantile_reduce(quantile, strategy)?
.into_series("quantile".into())
.cast(dtype)?
Expand Down
20 changes: 19 additions & 1 deletion test/explorer/series/struct_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ defmodule Explorer.Series.StructTest do
]
end

test "allows structs structs with special float values" do
test "allows structs with special float values" do
series = Series.from_list([%{a: :nan, b: :infinity, c: :neg_infinity}])

assert series.dtype == {:struct, [{"a", {:f, 64}}, {"b", {:f, 64}}, {"c", {:f, 64}}]}
Expand Down Expand Up @@ -126,6 +126,24 @@ defmodule Explorer.Series.StructTest do
assert series1.dtype == {:struct, [{"a", :string}, {"b", :string}]}
end

test "allow binaries inside structs" do
  # Three raw 16-byte binaries, followed by a nil entry for the same field.
  first = <<60, 15, 233, 144, 204, 179, 30, 148, 197, 140, 19, 96, 65, 131, 83, 195>>
  second = <<176, 176, 30, 208, 242, 105, 50, 4, 26, 51, 170, 216, 121, 119, 214, 218>>
  third = <<5, 77, 199, 172, 128, 195, 160, 241, 93, 61, 84, 116, 61, 141, 195, 130>>

  dtype = {:struct, [{"a", :binary}]}
  rows = [%{a: first}, %{a: second}, %{a: third}, %{a: nil}]

  series = Series.from_list(rows, dtype: dtype)

  assert series.dtype == dtype

  # The whole series must come back with the binaries intact.
  expected = [%{"a" => first}, %{"a" => second}, %{"a" => third}, %{"a" => nil}]
  assert Series.to_list(series) == expected

  # Single-row access must return the same values as the list conversion.
  assert series[1] === %{"a" => second}
  assert series[3] === %{"a" => nil}
end

test "errors when structs have mismatched types" do
assert_raise ArgumentError,
"the value \"a\" does not match the inferred dtype {:s, 64}",
Expand Down

0 comments on commit 02c68a3

Please sign in to comment.