From bbe3dc08ca2ac69b57e008fbd5e15ac435a8dde8 Mon Sep 17 00:00:00 2001 From: "Peter M. Stahl" Date: Thu, 5 Dec 2024 12:00:31 +0100 Subject: [PATCH] Convert string representations of languages and ISO codes to actual types in Python bindings (#411) --- lingua.pyi | 27 +++++++++++++++++ src/isocode.rs | 20 ++++++++---- src/language.rs | 13 +++++--- src/python.rs | 57 ++++++++++++++++++++++++++++++++--- tests/python/test_language.py | 26 ++++++++++++++++ 5 files changed, 127 insertions(+), 16 deletions(-) diff --git a/lingua.pyi b/lingua.pyi index 3f6deda7..0729920c 100644 --- a/lingua.pyi +++ b/lingua.pyi @@ -198,6 +198,15 @@ class Language(Enum): ValueError: if there is no language for the given ISO code """ + @classmethod + def from_str(cls, string: str) -> "Language": + """Return the language associated with the string representation + passed to this method. + + Raises: + ValueError: if there is no language for the given string representation + """ + class IsoCode639_1(Enum): """This enum specifies the ISO 639-1 code representations for the @@ -282,6 +291,15 @@ class IsoCode639_1(Enum): ZH = 74 ZU = 75 + @classmethod + def from_str(cls, string: str) -> "Language": + """Return the ISO 639-1 code associated with the string representation + passed to this method. + + Raises: + ValueError: if there is no ISO 639-1 code for the given string representation + """ + class IsoCode639_3(Enum): """This enum specifies the ISO 639-3 code representations for the @@ -366,6 +384,15 @@ class IsoCode639_3(Enum): ZHO = 74 ZUL = 75 + @classmethod + def from_str(cls, string: str) -> "Language": + """Return the ISO 639-3 code associated with the string representation + passed to this method. + + Raises: + ValueError: if there is no ISO 639-3 code for the given string representation + """ + class LanguageDetector: """This class detects the language of text.""" diff --git a/src/isocode.rs b/src/isocode.rs index 059fc67f..f6258a19 100644 --- a/src/isocode.rs +++ b/src/isocode.rs @@ -38,7 +38,10 @@ use strum_macros::{EnumIter, EnumString}; )] #[allow(clippy::upper_case_acronyms)] #[strum(ascii_case_insensitive)] -#[cfg_attr(feature = "python", pyo3::prelude::pyclass(eq, eq_int, frozen, hash, ord))] +#[cfg_attr( + feature = "python", + pyo3::prelude::pyclass(eq, eq_int, frozen, hash, ord) +)] pub enum IsoCode639_1 { #[cfg(feature = "afrikaans")] /// The ISO 639-1 code for [`Afrikaans`](crate::language::Language::Afrikaans) @@ -360,7 +363,10 @@ pub enum IsoCode639_1 { )] #[allow(clippy::upper_case_acronyms)] #[strum(ascii_case_insensitive)] -#[cfg_attr(feature = "python", pyo3::prelude::pyclass(eq, eq_int, frozen, hash, ord))] +#[cfg_attr( + feature = "python", + pyo3::prelude::pyclass(eq, eq_int, frozen, hash, ord) +)] pub enum IsoCode639_3 { #[cfg(feature = "afrikaans")] /// The ISO 639-3 code for [`Afrikaans`](crate::language::Language::Afrikaans) @@ -679,9 +685,9 @@ impl Display for IsoCode639_3 { #[cfg(test)] mod tests { - use std::str::FromStr; - use super::*; + use std::str::FromStr; + use strum::ParseError::VariantNotFound; #[test] fn assert_iso_code_639_1_string_representation_is_correct() { @@ -695,11 +701,13 @@ mod tests { #[test] fn assert_string_to_iso_code_639_1_is_correct() { - assert_eq!(IsoCode639_1::from_str("en").unwrap(), IsoCode639_1::EN); + assert_eq!(IsoCode639_1::from_str("en"), Ok(IsoCode639_1::EN)); + assert_eq!(IsoCode639_1::from_str("12"), Err(VariantNotFound)); } #[test] fn assert_string_to_iso_code_639_3_is_correct() { - assert_eq!(IsoCode639_3::from_str("eng").unwrap(), IsoCode639_3::ENG); + assert_eq!(IsoCode639_3::from_str("eng"), Ok(IsoCode639_3::ENG)); + assert_eq!(IsoCode639_3::from_str("123"), Err(VariantNotFound)); } } diff --git a/src/language.rs b/src/language.rs index 4642018e..92e784ac 100644 --- a/src/language.rs +++ b/src/language.rs @@ -42,7 +42,10 @@ use crate::isocode::{IsoCode639_1, IsoCode639_3}; )] #[serde(rename_all(serialize = "UPPERCASE", deserialize = "UPPERCASE"))] #[strum(ascii_case_insensitive)] -#[cfg_attr(feature = "python", pyo3::prelude::pyclass(eq, eq_int, frozen, hash, ord, rename_all = "UPPERCASE"))] +#[cfg_attr( + feature = "python", + pyo3::prelude::pyclass(eq, eq_int, frozen, hash, ord, rename_all = "UPPERCASE") +)] pub enum Language { #[cfg(feature = "afrikaans")] Afrikaans, @@ -1095,9 +1098,9 @@ impl Language { #[cfg(test)] mod tests { - use std::str::FromStr; - use crate::language::Language::*; + use std::str::FromStr; + use strum::ParseError::VariantNotFound; use super::*; @@ -1120,8 +1123,8 @@ mod tests { #[test] fn test_from_str() { - let language = Language::from_str("english").unwrap(); - assert_eq!(language, English); + assert_eq!(Language::from_str("english"), Ok(English)); + assert_eq!(Language::from_str("foo"), Err(VariantNotFound)); } #[test] diff --git a/src/python.rs b/src/python.rs index 0428a6e8..988d73f1 100644 --- a/src/python.rs +++ b/src/python.rs @@ -14,15 +14,15 @@ * limitations under the License. */ +use pyo3::exceptions::{PyException, PyValueError}; +use pyo3::prelude::*; +use pyo3::types::{PyTuple, PyType}; use std::any::Any; use std::collections::HashSet; use std::io; use std::panic; use std::path::PathBuf; - -use pyo3::exceptions::{PyException, PyValueError}; -use pyo3::prelude::*; -use pyo3::types::{PyTuple, PyType}; +use std::str::FromStr; use crate::builder::{ LanguageDetectorBuilder, MINIMUM_RELATIVE_DISTANCE_MESSAGE, MISSING_LANGUAGE_MESSAGE, @@ -34,6 +34,8 @@ use crate::language::Language; use crate::result::DetectionResult; use crate::writer::{LanguageModelFilesWriter, TestDataFilesWriter}; +const ENUM_MEMBER_NOT_FOUND_MESSAGE: &str = "Matching enum member not found"; + #[pymodule] fn lingua(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; @@ -169,6 +171,20 @@ impl IsoCode639_1 { fn name(&self) -> String { self.to_string().to_uppercase() } + + /// Return the ISO 639-1 code associated with the string representation + /// passed to this method. + /// + /// Raises: + /// ValueError: if there is no ISO 639-1 code for the given string representation + #[pyo3(name = "from_str")] + #[classmethod] + fn py_from_str(_cls: &Bound, string: &str) -> PyResult { + match Self::from_str(string) { + Ok(iso_code) => Ok(iso_code), + Err(_) => Err(PyValueError::new_err(ENUM_MEMBER_NOT_FOUND_MESSAGE)), + } + } } #[pymethods] @@ -177,6 +193,20 @@ impl IsoCode639_3 { fn name(&self) -> String { self.to_string().to_uppercase() } + + /// Return the ISO 639-3 code associated with the string representation + /// passed to this method. + /// + /// Raises: + /// ValueError: if there is no ISO 639-3 code for the given string representation + #[pyo3(name = "from_str")] + #[classmethod] + fn py_from_str(_cls: &Bound, string: &str) -> PyResult { + match Self::from_str(string) { + Ok(iso_code) => Ok(iso_code), + Err(_) => Err(PyValueError::new_err(ENUM_MEMBER_NOT_FOUND_MESSAGE)), + } + } } #[pymethods] @@ -245,6 +275,20 @@ impl Language { Self::from_iso_code_639_3(iso_code) } + /// Return the language associated with the string representation + /// passed to this method. + /// + /// Raises: + /// ValueError: if there is no language for the given string representation + #[pyo3(name = "from_str")] + #[classmethod] + fn py_from_str(_cls: &Bound, string: &str) -> PyResult { + match Self::from_str(string) { + Ok(language) => Ok(language), + Err(_) => Err(PyValueError::new_err(ENUM_MEMBER_NOT_FOUND_MESSAGE)), + } + } + /// Return the ISO 639-1 code of this language. #[pyo3(name = "iso_code_639_1")] #[getter] @@ -319,7 +363,10 @@ impl LanguageDetectorBuilder { /// with all built-in languages except those passed to this method. #[pyo3(name = "from_all_languages_without", signature = (*languages))] #[classmethod] - fn py_from_all_languages_without(_cls: &Bound, languages: &Bound) -> PyResult { + fn py_from_all_languages_without( + _cls: &Bound, + languages: &Bound, + ) -> PyResult { match languages.extract::>() { Ok(vector) => match panic::catch_unwind(|| Self::from_all_languages_without(&vector)) { Ok(builder) => Ok(builder), diff --git a/tests/python/test_language.py b/tests/python/test_language.py index c9d5cbe2..43308d64 100644 --- a/tests/python/test_language.py +++ b/tests/python/test_language.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + from lingua import IsoCode639_1, IsoCode639_3, Language @@ -20,6 +22,14 @@ def test_iso_code_639_1_name(): assert IsoCode639_1.EN.name == "EN" +def test_iso_code_639_1_from_str(): + assert IsoCode639_1.from_str("EN") == IsoCode639_1.EN + assert IsoCode639_1.from_str("en") == IsoCode639_1.EN + assert IsoCode639_1.from_str("eN") == IsoCode639_1.EN + with pytest.raises(ValueError, match="Matching enum member not found"): + IsoCode639_1.from_str("12") + + def test_iso_code_639_1_is_comparable(): assert IsoCode639_1.EN == IsoCode639_1.EN assert IsoCode639_1.EN != IsoCode639_1.DE @@ -31,6 +41,14 @@ def test_iso_code_639_3_name(): assert IsoCode639_3.ENG.name == "ENG" +def test_iso_code_639_3_from_str(): + assert IsoCode639_3.from_str("ENG") == IsoCode639_3.ENG + assert IsoCode639_3.from_str("eng") == IsoCode639_3.ENG + assert IsoCode639_3.from_str("eNg") == IsoCode639_3.ENG + with pytest.raises(ValueError, match="Matching enum member not found"): + IsoCode639_3.from_str("123") + + def test_iso_code_639_3_is_comparable(): assert IsoCode639_3.ENG == IsoCode639_3.ENG assert IsoCode639_3.ENG != IsoCode639_3.DEU @@ -42,6 +60,14 @@ def test_language_name(): assert Language.ENGLISH.name == "ENGLISH" +def test_language_from_str(): + assert Language.from_str("ENGLISH") == Language.ENGLISH + assert Language.from_str("english") == Language.ENGLISH + assert Language.from_str("EnGlIsH") == Language.ENGLISH + with pytest.raises(ValueError, match="Matching enum member not found"): + Language.from_str("FOOBAR") + + def test_language_is_comparable(): assert Language.ENGLISH == Language.ENGLISH assert Language.ENGLISH != Language.GERMAN