Skip to content

Commit

Permalink
Convert string representations of languages and ISO codes to actual t…
Browse files Browse the repository at this point in the history
…ypes in Python bindings (#411)
  • Loading branch information
pemistahl authored Dec 5, 2024
1 parent 910a7ee commit bbe3dc0
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 16 deletions.
27 changes: 27 additions & 0 deletions lingua.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ class Language(Enum):
ValueError: if there is no language for the given ISO code
"""

@classmethod
def from_str(cls, string: str) -> "Language":
    """Return the language associated with the string representation
    passed to this method.

    The lookup ignores ASCII case, so "ENGLISH", "english" and "EnGlIsH"
    all resolve to the same member.

    Raises:
        ValueError: if there is no language for the given string representation
    """


class IsoCode639_1(Enum):
"""This enum specifies the ISO 639-1 code representations for the
Expand Down Expand Up @@ -282,6 +291,15 @@ class IsoCode639_1(Enum):
ZH = 74
ZU = 75

@classmethod
def from_str(cls, string: str) -> "IsoCode639_1":
    """Return the ISO 639-1 code associated with the string representation
    passed to this method.

    The lookup ignores ASCII case, so "EN", "en" and "eN" all resolve
    to the same member.

    Raises:
        ValueError: if there is no ISO 639-1 code for the given string representation
    """


class IsoCode639_3(Enum):
"""This enum specifies the ISO 639-3 code representations for the
Expand Down Expand Up @@ -366,6 +384,15 @@ class IsoCode639_3(Enum):
ZHO = 74
ZUL = 75

@classmethod
def from_str(cls, string: str) -> "IsoCode639_3":
    """Return the ISO 639-3 code associated with the string representation
    passed to this method.

    The lookup ignores ASCII case, so "ENG", "eng" and "eNg" all resolve
    to the same member.

    Raises:
        ValueError: if there is no ISO 639-3 code for the given string representation
    """


class LanguageDetector:
"""This class detects the language of text."""
Expand Down
20 changes: 14 additions & 6 deletions src/isocode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ use strum_macros::{EnumIter, EnumString};
)]
#[allow(clippy::upper_case_acronyms)]
#[strum(ascii_case_insensitive)]
#[cfg_attr(feature = "python", pyo3::prelude::pyclass(eq, eq_int, frozen, hash, ord))]
#[cfg_attr(
feature = "python",
pyo3::prelude::pyclass(eq, eq_int, frozen, hash, ord)
)]
pub enum IsoCode639_1 {
#[cfg(feature = "afrikaans")]
/// The ISO 639-1 code for [`Afrikaans`](crate::language::Language::Afrikaans)
Expand Down Expand Up @@ -360,7 +363,10 @@ pub enum IsoCode639_1 {
)]
#[allow(clippy::upper_case_acronyms)]
#[strum(ascii_case_insensitive)]
#[cfg_attr(feature = "python", pyo3::prelude::pyclass(eq, eq_int, frozen, hash, ord))]
#[cfg_attr(
feature = "python",
pyo3::prelude::pyclass(eq, eq_int, frozen, hash, ord)
)]
pub enum IsoCode639_3 {
#[cfg(feature = "afrikaans")]
/// The ISO 639-3 code for [`Afrikaans`](crate::language::Language::Afrikaans)
Expand Down Expand Up @@ -679,9 +685,9 @@ impl Display for IsoCode639_3 {

#[cfg(test)]
mod tests {
use std::str::FromStr;

use super::*;
use std::str::FromStr;
use strum::ParseError::VariantNotFound;

#[test]
fn assert_iso_code_639_1_string_representation_is_correct() {
Expand All @@ -695,11 +701,13 @@ mod tests {

#[test]
fn assert_string_to_iso_code_639_1_is_correct() {
    // Compare the whole `Result` so the success and the failure path are
    // checked the same way, without an intermediate `unwrap`.
    assert_eq!(IsoCode639_1::from_str("en"), Ok(IsoCode639_1::EN));
    assert_eq!(IsoCode639_1::from_str("12"), Err(VariantNotFound));
}

#[test]
fn assert_string_to_iso_code_639_3_is_correct() {
    // Compare the whole `Result` so the success and the failure path are
    // checked the same way, without an intermediate `unwrap`.
    assert_eq!(IsoCode639_3::from_str("eng"), Ok(IsoCode639_3::ENG));
    assert_eq!(IsoCode639_3::from_str("123"), Err(VariantNotFound));
}
}
13 changes: 8 additions & 5 deletions src/language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ use crate::isocode::{IsoCode639_1, IsoCode639_3};
)]
#[serde(rename_all(serialize = "UPPERCASE", deserialize = "UPPERCASE"))]
#[strum(ascii_case_insensitive)]
#[cfg_attr(feature = "python", pyo3::prelude::pyclass(eq, eq_int, frozen, hash, ord, rename_all = "UPPERCASE"))]
#[cfg_attr(
feature = "python",
pyo3::prelude::pyclass(eq, eq_int, frozen, hash, ord, rename_all = "UPPERCASE")
)]
pub enum Language {
#[cfg(feature = "afrikaans")]
Afrikaans,
Expand Down Expand Up @@ -1095,9 +1098,9 @@ impl Language {

#[cfg(test)]
mod tests {
use std::str::FromStr;

use crate::language::Language::*;
use std::str::FromStr;
use strum::ParseError::VariantNotFound;

use super::*;

Expand All @@ -1120,8 +1123,8 @@ mod tests {

#[test]
fn test_from_str() {
    // Compare the whole `Result` so the success and the failure path are
    // checked the same way, without an intermediate `unwrap`.
    assert_eq!(Language::from_str("english"), Ok(English));
    assert_eq!(Language::from_str("foo"), Err(VariantNotFound));
}

#[test]
Expand Down
57 changes: 52 additions & 5 deletions src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
* limitations under the License.
*/

use pyo3::exceptions::{PyException, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyTuple, PyType};
use std::any::Any;
use std::collections::HashSet;
use std::io;
use std::panic;
use std::path::PathBuf;

use pyo3::exceptions::{PyException, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyTuple, PyType};
use std::str::FromStr;

use crate::builder::{
LanguageDetectorBuilder, MINIMUM_RELATIVE_DISTANCE_MESSAGE, MISSING_LANGUAGE_MESSAGE,
Expand All @@ -34,6 +34,8 @@ use crate::language::Language;
use crate::result::DetectionResult;
use crate::writer::{LanguageModelFilesWriter, TestDataFilesWriter};

// Message of the `ValueError` raised by the Python `from_str` bindings when
// the given string matches no enum member.
const ENUM_MEMBER_NOT_FOUND_MESSAGE: &str = "Matching enum member not found";

#[pymodule]
fn lingua(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<ConfidenceValue>()?;
Expand Down Expand Up @@ -169,6 +171,20 @@ impl IsoCode639_1 {
fn name(&self) -> String {
self.to_string().to_uppercase()
}

/// Return the ISO 639-1 code associated with the string representation
/// passed to this method.
///
/// Raises:
///     ValueError: if there is no ISO 639-1 code for the given string representation
#[pyo3(name = "from_str")]
#[classmethod]
fn py_from_str(_cls: &Bound<PyType>, string: &str) -> PyResult<Self> {
    // The parse error itself is discarded; any failure surfaces in Python
    // as a `ValueError` carrying the fixed message.
    Self::from_str(string).map_err(|_| PyValueError::new_err(ENUM_MEMBER_NOT_FOUND_MESSAGE))
}
}

#[pymethods]
Expand All @@ -177,6 +193,20 @@ impl IsoCode639_3 {
fn name(&self) -> String {
self.to_string().to_uppercase()
}

/// Return the ISO 639-3 code associated with the string representation
/// passed to this method.
///
/// Raises:
///     ValueError: if there is no ISO 639-3 code for the given string representation
#[pyo3(name = "from_str")]
#[classmethod]
fn py_from_str(_cls: &Bound<PyType>, string: &str) -> PyResult<Self> {
    // The parse error itself is discarded; any failure surfaces in Python
    // as a `ValueError` carrying the fixed message.
    Self::from_str(string).map_err(|_| PyValueError::new_err(ENUM_MEMBER_NOT_FOUND_MESSAGE))
}
}

#[pymethods]
Expand Down Expand Up @@ -245,6 +275,20 @@ impl Language {
Self::from_iso_code_639_3(iso_code)
}

/// Return the language associated with the string representation
/// passed to this method.
///
/// Raises:
///     ValueError: if there is no language for the given string representation
#[pyo3(name = "from_str")]
#[classmethod]
fn py_from_str(_cls: &Bound<PyType>, string: &str) -> PyResult<Self> {
    // The parse error itself is discarded; any failure surfaces in Python
    // as a `ValueError` carrying the fixed message.
    Self::from_str(string).map_err(|_| PyValueError::new_err(ENUM_MEMBER_NOT_FOUND_MESSAGE))
}

/// Return the ISO 639-1 code of this language.
#[pyo3(name = "iso_code_639_1")]
#[getter]
Expand Down Expand Up @@ -319,7 +363,10 @@ impl LanguageDetectorBuilder {
/// with all built-in languages except those passed to this method.
#[pyo3(name = "from_all_languages_without", signature = (*languages))]
#[classmethod]
fn py_from_all_languages_without(_cls: &Bound<PyType>, languages: &Bound<PyTuple>) -> PyResult<Self> {
fn py_from_all_languages_without(
_cls: &Bound<PyType>,
languages: &Bound<PyTuple>,
) -> PyResult<Self> {
match languages.extract::<Vec<Language>>() {
Ok(vector) => match panic::catch_unwind(|| Self::from_all_languages_without(&vector)) {
Ok(builder) => Ok(builder),
Expand Down
26 changes: 26 additions & 0 deletions tests/python/test_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from lingua import IsoCode639_1, IsoCode639_3, Language


def test_iso_code_639_1_name():
    """The `name` of an ISO 639-1 member is its upper-case code."""
    assert IsoCode639_1.EN.name == "EN"


def test_iso_code_639_1_from_str():
    """Parsing accepts any letter case and rejects unknown codes."""
    for spelling in ("EN", "en", "eN"):
        assert IsoCode639_1.from_str(spelling) == IsoCode639_1.EN

    # A string that matches no member must raise, not return a default.
    with pytest.raises(ValueError, match="Matching enum member not found"):
        IsoCode639_1.from_str("12")


def test_iso_code_639_1_is_comparable():
assert IsoCode639_1.EN == IsoCode639_1.EN
assert IsoCode639_1.EN != IsoCode639_1.DE
Expand All @@ -31,6 +41,14 @@ def test_iso_code_639_3_name():
assert IsoCode639_3.ENG.name == "ENG"


def test_iso_code_639_3_from_str():
    """Parsing accepts any letter case and rejects unknown codes."""
    for spelling in ("ENG", "eng", "eNg"):
        assert IsoCode639_3.from_str(spelling) == IsoCode639_3.ENG

    # A string that matches no member must raise, not return a default.
    with pytest.raises(ValueError, match="Matching enum member not found"):
        IsoCode639_3.from_str("123")


def test_iso_code_639_3_is_comparable():
assert IsoCode639_3.ENG == IsoCode639_3.ENG
assert IsoCode639_3.ENG != IsoCode639_3.DEU
Expand All @@ -42,6 +60,14 @@ def test_language_name():
assert Language.ENGLISH.name == "ENGLISH"


def test_language_from_str():
    """Parsing accepts any letter case and rejects unknown language names."""
    for spelling in ("ENGLISH", "english", "EnGlIsH"):
        assert Language.from_str(spelling) == Language.ENGLISH

    # A string that matches no member must raise, not return a default.
    with pytest.raises(ValueError, match="Matching enum member not found"):
        Language.from_str("FOOBAR")


def test_language_is_comparable():
assert Language.ENGLISH == Language.ENGLISH
assert Language.ENGLISH != Language.GERMAN
Expand Down

0 comments on commit bbe3dc0

Please sign in to comment.