Skip to content

Commit

Permalink
better python interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshix-1 committed Nov 23, 2023
1 parent 93ad2a2 commit 0e3bc3f
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 35 deletions.
2 changes: 1 addition & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ pub enum Language {{
}}
impl Language {{
pub fn read_words(self, length: usize) -> StringChunkIter<'static> {{
pub fn read_words(self, length: usize) -> StringChunkIter {{
let words: &'static str = match self {{
{}
}};
Expand Down
77 changes: 69 additions & 8 deletions src/language/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,33 @@
// SPDX-License-Identifier: EUPL-1.2

use cfg_if::cfg_if;

#[cfg(feature = "pyo3")]
use pyo3::create_exception;
#[cfg(feature = "pyo3")]
use pyo3::exceptions::PyValueError;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;

pub struct StringChunkIter<'a> {
word_length: usize,
index: usize,
string: &'a str,
cfg_if! {
if #[cfg(feature = "pyo3")] {
#[pyclass]
pub struct StringChunkIter {
word_length: usize,
index: usize,
string: &'static str,
}
} else {
pub struct StringChunkIter {
word_length: usize,
index: usize,
string: &'static str,
}
}
}

impl<'a> StringChunkIter<'a> {
pub fn new(word_length: usize, string: &'a str) -> StringChunkIter<'a> {
impl StringChunkIter {
pub fn new(word_length: usize, string: &'static str) -> StringChunkIter {
StringChunkIter {
word_length,
index: 0,
Expand All @@ -19,8 +36,8 @@ impl<'a> StringChunkIter<'a> {
}
}

impl<'a> Iterator for StringChunkIter<'a> {
type Item = &'a str;
impl Iterator for StringChunkIter {
type Item = &'static str;

fn next(&mut self) -> Option<Self::Item> {
let index = self.index;
Expand All @@ -33,4 +50,48 @@ impl<'a> Iterator for StringChunkIter<'a> {
}
}

#[cfg(feature = "pyo3")]
#[pymethods]
impl StringChunkIter {
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}

fn __next__(&mut self) -> Option<&'static str> {
self.next()
}

fn __len__(&self) -> usize {
self.string.len() / self.word_length
}
}

include!(concat!(env!("OUT_DIR"), "/language.rs"));

#[cfg(feature = "pyo3")]
create_exception!(hangman_solver, UnknownLanguageError, PyValueError);

#[cfg(feature = "pyo3")]
#[pymethods]
impl Language {
#[getter]
fn value(&self) -> &'static str {
self.name()
}

#[staticmethod]
fn values() -> Vec<Language> {
Language::all()
}

#[staticmethod]
#[pyo3(signature = (name, default = None))]
pub fn parse_string(
name: &str,
default: Option<Language>,
) -> PyResult<Language> {
Language::from_string(name)
.or(default)
.ok_or(UnknownLanguageError::new_err(name.to_owned()))
}
}
34 changes: 9 additions & 25 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
mod language;
mod solver;

pub use crate::language::Language;
pub use crate::language::{Language, StringChunkIter};

pub use crate::solver::{solve_hangman_puzzle, HangmanResult};

#[cfg(feature = "pyo3")]
use pyo3::create_exception;
#[cfg(feature = "pyo3")]
use pyo3::exceptions::PyValueError;
pub use crate::language::UnknownLanguageError;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;

Expand All @@ -27,35 +25,21 @@ pub fn solve(
))
}

#[cfg(feature = "pyo3")]
create_exception!(hangman_solver, UnknownLanguageError, PyValueError);

#[cfg(feature = "pyo3")]
#[pyfunction]
#[pyo3(signature = (name, default = None))]
pub fn parse_language(
name: &str,
default: Option<Language>,
) -> PyResult<Language> {
Language::from_string(name)
.or(default)
.ok_or(UnknownLanguageError::new_err(name.to_owned()))
#[pyo3(signature = (language, word_length))]
pub fn read_words_with_length(
language: Language,
word_length: usize,
) -> PyResult<StringChunkIter> {
Ok(language.read_words(word_length))
}

// #[pyfunction]
// #[pyo3(signature = (language, word_length))]
// pub fn read_words_with_length(
// language: Language,
// word_length: usize,
// ) -> PyResult<StringChunkIter<'static>> {
// Ok(language.read_words(word_length))
// }

#[cfg(feature = "pyo3")]
#[pymodule]
pub(crate) fn hangman_solver(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(solve, m)?)?;
m.add_function(wrap_pyfunction!(parse_language, m)?)?;
m.add_function(wrap_pyfunction!(read_words_with_length, m)?)?;
m.add(
"UnknownLanguageError",
py.get_type::<UnknownLanguageError>(),
Expand Down
2 changes: 1 addition & 1 deletion src/solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl std::fmt::Display for HangmanResult {
}
}

fn read_words(language: Language, length: usize) -> StringChunkIter<'static> {
fn read_words(language: Language, length: usize) -> StringChunkIter {
language.read_words(length)
}

Expand Down

0 comments on commit 0e3bc3f

Please sign in to comment.