From 0e3bc3f8ca943b6a8a92e9a5102cc5d2cb2ba86a Mon Sep 17 00:00:00 2001 From: Joshix Date: Thu, 23 Nov 2023 18:00:00 +0000 Subject: [PATCH] better python interface --- build.rs | 2 +- src/language/mod.rs | 77 ++++++++++++++++++++++++++++++++++++++++----- src/lib.rs | 34 ++++++-------------- src/solver/mod.rs | 2 +- 4 files changed, 80 insertions(+), 35 deletions(-) diff --git a/build.rs b/build.rs index 3b7ba63..a738a5e 100644 --- a/build.rs +++ b/build.rs @@ -161,7 +161,7 @@ pub enum Language {{ }} impl Language {{ - pub fn read_words(self, length: usize) -> StringChunkIter<'static> {{ + pub fn read_words(self, length: usize) -> StringChunkIter {{ let words: &'static str = match self {{ {} }}; diff --git a/src/language/mod.rs b/src/language/mod.rs index bb8c605..fb58e6e 100644 --- a/src/language/mod.rs +++ b/src/language/mod.rs @@ -1,16 +1,33 @@ // SPDX-License-Identifier: EUPL-1.2 +use cfg_if::cfg_if; + +#[cfg(feature = "pyo3")] +use pyo3::create_exception; +#[cfg(feature = "pyo3")] +use pyo3::exceptions::PyValueError; #[cfg(feature = "pyo3")] use pyo3::prelude::*; -pub struct StringChunkIter<'a> { - word_length: usize, - index: usize, - string: &'a str, +cfg_if! { + if #[cfg(feature = "pyo3")] { + #[pyclass] + pub struct StringChunkIter { + word_length: usize, + index: usize, + string: &'static str, + } + } else { + pub struct StringChunkIter { + word_length: usize, + index: usize, + string: &'static str, + } + } } -impl<'a> StringChunkIter<'a> { - pub fn new(word_length: usize, string: &'a str) -> StringChunkIter<'a> { +impl StringChunkIter { + pub fn new(word_length: usize, string: &'static str) -> StringChunkIter { StringChunkIter { word_length, index: 0, @@ -19,8 +36,8 @@ impl<'a> StringChunkIter<'a> { } } -impl<'a> Iterator for StringChunkIter<'a> { - type Item = &'a str; +impl Iterator for StringChunkIter { + type Item = &'static str; fn next(&mut self) -> Option { let index = self.index; @@ -33,4 +50,48 @@ impl<'a> Iterator for StringChunkIter<'a> { } } +#[cfg(feature = "pyo3")] +#[pymethods] +impl StringChunkIter { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(&mut self) -> Option<&'static str> { + self.next() + } + + fn __len__(&self) -> usize { + self.string.len() / self.word_length + } +} + include!(concat!(env!("OUT_DIR"), "/language.rs")); + +#[cfg(feature = "pyo3")] +create_exception!(hangman_solver, UnknownLanguageError, PyValueError); + +#[cfg(feature = "pyo3")] +#[pymethods] +impl Language { + #[getter] + fn value(&self) -> &'static str { + self.name() + } + + #[staticmethod] + fn values() -> Vec { + Language::all() + } + + #[staticmethod] + #[pyo3(signature = (name, default = None))] + pub fn parse_string( + name: &str, + default: Option, + ) -> PyResult { + Language::from_string(name) + .or(default) + .ok_or(UnknownLanguageError::new_err(name.to_owned())) + } +} diff --git a/src/lib.rs b/src/lib.rs index 6db3180..6f75fac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,12 @@ mod language; mod solver; -pub use crate::language::Language; +pub use crate::language::{Language, StringChunkIter}; pub use crate::solver::{solve_hangman_puzzle, HangmanResult}; #[cfg(feature = "pyo3")] -use pyo3::create_exception; -#[cfg(feature = "pyo3")] -use pyo3::exceptions::PyValueError; +pub use crate::language::UnknownLanguageError; #[cfg(feature = "pyo3")] use pyo3::prelude::*; @@ -27,35 +25,21 @@ pub fn solve( )) } -#[cfg(feature = "pyo3")] -create_exception!(hangman_solver, UnknownLanguageError, PyValueError); - #[cfg(feature = "pyo3")] #[pyfunction] -#[pyo3(signature = (name, default = None))] -pub fn parse_language( - name: &str, - default: Option, -) -> PyResult { - Language::from_string(name) - .or(default) - .ok_or(UnknownLanguageError::new_err(name.to_owned())) +#[pyo3(signature = (language, word_length))] +pub fn read_words_with_length( + language: Language, + word_length: usize, +) -> PyResult { + Ok(language.read_words(word_length)) } -// #[pyfunction] -// #[pyo3(signature = (language, word_length))] -// pub fn read_words_with_length( -// language: Language, -// word_length: usize, -// ) -> PyResult> { -// Ok(language.read_words(word_length)) -// } - #[cfg(feature = "pyo3")] #[pymodule] pub(crate) fn hangman_solver(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(solve, m)?)?; - m.add_function(wrap_pyfunction!(parse_language, m)?)?; + m.add_function(wrap_pyfunction!(read_words_with_length, m)?)?; m.add( "UnknownLanguageError", py.get_type::(), diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 71f80d9..ff21704 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -125,7 +125,7 @@ impl std::fmt::Display for HangmanResult { } } -fn read_words(language: Language, length: usize) -> StringChunkIter<'static> { +fn read_words(language: Language, length: usize) -> StringChunkIter { language.read_words(length) }