Skip to content

Commit

Permalink
[PROTOTYPE]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Mar 6, 2024
1 parent ff947ae commit 13362c8
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 27 deletions.
48 changes: 24 additions & 24 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,30 +92,30 @@ bindings = 'pyo3'
features = ["pyo3/extension-module"]
include = [{ path = "path/**/*", format = "sdist" }]

[tool.ruff.lint]
extend-select = ['Q', 'RUF100', 'C90', 'I']
extend-ignore = [
'E721', # using type() instead of isinstance() - we use this in tests
]
flake8-quotes = {inline-quotes = 'single', multiline-quotes = 'double'}
mccabe = { max-complexity = 13 }
isort = { known-first-party = ['swarms_core', 'tests'] }

[tool.ruff.format]
quote-style = 'single'

[tool.pytest.ini_options]
testpaths = 'tests'
log_format = '%(name)s %(levelname)s: %(message)s'
timeout = 30
xfail_strict = true
# min, max, mean, stddev, median, iqr, outliers, ops, rounds, iterations
addopts = [
'--benchmark-columns', 'min,mean,stddev,outliers,rounds,iterations',
'--benchmark-group-by', 'group',
'--benchmark-warmup', 'on',
'--benchmark-disable', # this is enable by `make benchmark` when you actually want to run benchmarks
]
# [tool.ruff.lint]
# extend-select = ['Q', 'RUF100', 'C90', 'I']
# extend-ignore = [
# 'E721', # using type() instead of isinstance() - we use this in tests
# ]
# flake8-quotes = {inline-quotes = 'single', multiline-quotes = 'double'}
# mccabe = { max-complexity = 13 }
# isort = { known-first-party = ['swarms_core', 'tests'] }

# [tool.ruff.format]
# quote-style = 'single'

# [tool.pytest.ini_options]
# testpaths = 'tests'
# log_format = '%(name)s %(levelname)s: %(message)s'
# timeout = 30
# xfail_strict = true
# # min, max, mean, stddev, median, iqr, outliers, ops, rounds, iterations
# addopts = [
# '--benchmark-columns', 'min,mean,stddev,outliers,rounds,iterations',
# '--benchmark-group-by', 'group',
# '--benchmark-warmup', 'on',
# '--benchmark-disable', # this is enable by `make benchmark` when you actually want to run benchmarks
# ]

[tool.coverage.run]
source = ['swarms_core']
Expand Down
File renamed without changes.
3 changes: 2 additions & 1 deletion python/swarms_core/_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
import pyrust_parallel

def my_python_function():
print("Executing Python function")

g
pyrust_parallel.execute_in_parallel(my_python_function, 10)
3 changes: 1 addition & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@ fn execute_in_parallel(py: Python, py_func: PyObject, n: usize) -> PyResult<Vec<
// Use Rayon to execute the Python function in parallel
let results: Result<Vec<_>, _> = (0..n).into_par_iter()
.map(|_| {
// Temporarily release the GIL and execute the function in a new thread
Python::with_gil(|py| {
// Safely call the Python function and convert the result back to PyObject
py_func.call1(py, ()).map(|res| res.to_object(py))
})
})
.collect();

results
results.map_err(|e| e.into())
}

/// A Python module implemented in Rust.
Expand Down
45 changes: 45 additions & 0 deletions tests/test_lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#[cfg(test)]
mod tests {
use super::*;
use pyo3::types::IntoPyDict;

#[test]
fn test_execute_in_parallel() {
let gil = Python::acquire_gil();
let py = gil.python();
let py_func = py
.run(
"def test_func(): return 42",
None,
Some(py.import("builtins").unwrap().into_py_dict(py)),
)
.unwrap()
.extract::<PyObject>(py)
.unwrap();

let result = execute_in_parallel(py, py_func, 10).unwrap();
assert_eq!(result.len(), 10);
for res in result {
let res: i32 = res.extract(py).unwrap();
assert_eq!(res, 42);
}
}

#[test]
fn test_execute_in_parallel_with_error() {
let gil = Python::acquire_gil();
let py = gil.python();
let py_func = py
.run(
"def test_func(): raise ValueError('test error')",
None,
Some(py.import("builtins").unwrap().into_py_dict(py)),
)
.unwrap()
.extract::<PyObject>(py)
.unwrap();

let result = execute_in_parallel(py, py_func, 10);
assert!(result.is_err());
}
}

0 comments on commit 13362c8

Please sign in to comment.