diff --git a/UPGRADE.md b/UPGRADE.md index 393e23e2..8c4f5546 100644 --- a/UPGRADE.md +++ b/UPGRADE.md @@ -19,8 +19,9 @@ This document is intended to simplify upgrading to newer versions by extending t See also [enif\_send](https://www.erlang.org/doc/man/erl_nif.html#enif_send). -3. As `Term::get_type` is now implemented using `enif_get_type`, some cases of - the `TermType` `enum` are changed, removed, or added: +3. As `Term::get_type` is now implemented using `enif_get_type` on all + non-Windows systems, some cases of the `TermType` `enum` are changed, + removed, or added (on all systems): 1. `EmptyList` is dropped, `List` is returned for both empty and non-empty lists 2. `Exception` is dropped diff --git a/rustler/src/dynamic.rs b/rustler/src/dynamic.rs index ad73f888..19b0a98c 100644 --- a/rustler/src/dynamic.rs +++ b/rustler/src/dynamic.rs @@ -1,3 +1,5 @@ +use std::ffi::c_double; + #[cfg(feature = "nif_version_2_15")] use rustler_sys::ErlNifTermType; @@ -43,7 +45,7 @@ impl From for TermType { } pub fn get_type(term: Term) -> TermType { - if cfg!(nif_version_2_15) { + if cfg!(feature = "nif_version_2_15") && !cfg!(target_family = "windows") { term.get_erl_type().into() } else if term.is_atom() { TermType::Atom @@ -56,7 +58,11 @@ pub fn get_type(term: Term) -> TermType { } else if term.is_map() { TermType::Map } else if term.is_number() { - TermType::Float + if term.is_float() { + TermType::Float + } else { + TermType::Integer + } } else if term.is_pid() { TermType::Pid } else if term.is_port() { @@ -98,4 +104,15 @@ impl<'a> Term<'a> { impl_check!(is_port); impl_check!(is_ref); impl_check!(is_tuple); + + pub fn is_float(self) -> bool { + let mut val: c_double = 0.0; + unsafe { + rustler_sys::enif_get_double(self.get_env().as_c_arg(), self.as_c_arg(), &mut val) == 1 + } + } + + pub fn is_integer(self) -> bool { + self.is_number() && !self.is_float() + } } diff --git a/rustler/src/term.rs b/rustler/src/term.rs index 27bf5950..b0d81b4c 100644 --- a/rustler/src/term.rs +++ b/rustler/src/term.rs @@ -124,7 +124,7 @@ impl<'a> Term<'a> { #[cfg(feature = "nif_version_2_15")] pub fn get_erl_type(&self) -> rustler_sys::ErlNifTermType { - unsafe { rustler_sys::enif_term_type(self.env.as_c_arg(), &self.as_c_arg()) } + unsafe { rustler_sys::enif_term_type(self.env.as_c_arg(), self.as_c_arg()) } } } diff --git a/rustler_sys/build.rs b/rustler_sys/build.rs index f873b023..70bdfe81 100644 --- a/rustler_sys/build.rs +++ b/rustler_sys/build.rs @@ -848,7 +848,7 @@ fn build_api(b: &mut dyn ApiBuilder, opts: &GenerateOptions) { b.func( "ErlNifTermType", "enif_term_type", - "env: *mut ErlNifEnv, term: *const ERL_NIF_TERM", + "env: *mut ErlNifEnv, term: ERL_NIF_TERM", ); b.func("c_int", "enif_is_pid_undefined", "pid: *const ErlNifPid"); diff --git a/rustler_tests/lib/rustler_test.ex b/rustler_tests/lib/rustler_test.ex index 7de1d85a..1adaded7 100644 --- a/rustler_tests/lib/rustler_test.ex +++ b/rustler_tests/lib/rustler_test.ex @@ -47,6 +47,7 @@ defmodule RustlerTest do def term_cmp(_, _), do: err() def term_internal_hash(_, _), do: err() def term_phash2_hash(_), do: err() + def term_type(_term), do: err() def sum_map_values(_), do: err() def map_entries_sorted(_), do: err() diff --git a/rustler_tests/native/rustler_test/src/lib.rs b/rustler_tests/native/rustler_test/src/lib.rs index 6b0e4ece..efe2372e 100644 --- a/rustler_tests/native/rustler_test/src/lib.rs +++ b/rustler_tests/native/rustler_test/src/lib.rs @@ -30,6 +30,7 @@ rustler::init!( test_term::term_cmp, test_term::term_internal_hash, test_term::term_phash2_hash, + test_term::term_type, test_map::sum_map_values, test_map::map_entries_sorted, test_map::map_from_arrays, diff --git a/rustler_tests/native/rustler_test/src/test_term.rs b/rustler_tests/native/rustler_test/src/test_term.rs index a499d80f..1bafdbea 100644 --- a/rustler_tests/native/rustler_test/src/test_term.rs +++ b/rustler_tests/native/rustler_test/src/test_term.rs @@ -7,6 +7,19 @@ mod atoms { equal, less, greater, + // Term types + atom, + binary, + float, + fun, + integer, + list, + map, + pid, + port, + reference, + tuple, + unknown, } } @@ -40,3 +53,21 @@ pub fn term_internal_hash(term: Term, salt: u32) -> u32 { pub fn term_phash2_hash(term: Term) -> u32 { term.hash_phash2() } + +#[rustler::nif] +pub fn term_type(term: Term) -> Atom { + match term.get_type() { + rustler::TermType::Atom => atoms::atom(), + rustler::TermType::Binary => atoms::binary(), + rustler::TermType::Fun => atoms::fun(), + rustler::TermType::List => atoms::list(), + rustler::TermType::Map => atoms::map(), + rustler::TermType::Integer => atoms::integer(), + rustler::TermType::Float => atoms::float(), + rustler::TermType::Pid => atoms::pid(), + rustler::TermType::Port => atoms::port(), + rustler::TermType::Ref => atoms::reference(), + rustler::TermType::Tuple => atoms::tuple(), + rustler::TermType::Unknown => atoms::unknown(), + } +} diff --git a/rustler_tests/test/term_test.exs b/rustler_tests/test/term_test.exs index 99f663b4..35bbab9a 100644 --- a/rustler_tests/test/term_test.exs +++ b/rustler_tests/test/term_test.exs @@ -67,4 +67,17 @@ defmodule RustlerTest.TermTest do assert unique > 50 end + + test "term type" do + assert RustlerTest.term_type(:foo) == :atom + assert RustlerTest.term_type("foo") == :binary + assert RustlerTest.term_type(42.2) == :float + assert RustlerTest.term_type(42) == :integer + assert RustlerTest.term_type(%{}) == :map + assert RustlerTest.term_type([]) == :list + assert RustlerTest.term_type({:ok, 42}) == :tuple + assert RustlerTest.term_type(self()) == :pid + assert RustlerTest.term_type(& &1) == :fun + assert RustlerTest.term_type(make_ref()) == :reference + end end