diff --git a/Cargo.toml b/Cargo.toml index aa8ce6b..47cfead 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "stochastic-rs" -version = "0.4.1" +version = "0.4.2" edition = "2021" license = "MIT" description = "A Rust library for stochastic processes" diff --git a/src/processes/poisson.rs b/src/processes/poisson.rs index 959d550..96fa56b 100644 --- a/src/processes/poisson.rs +++ b/src/processes/poisson.rs @@ -1,28 +1,34 @@ -use ndarray::Array1; +use ndarray::{Array0, Array1, Axis, Dim}; use ndarray_rand::rand_distr::{Distribution, Exp}; use ndarray_rand::rand_distr::{Normal, Poisson}; use ndarray_rand::RandomExt; use rand::thread_rng; -pub fn poisson(n: usize, lambda: usize, t_max: Option) -> Vec { - if n == 0 || lambda == 0 { - panic!("lambda, t and n must be positive integers"); - } +pub fn poisson(lambda: f64, n: Option, t_max: Option) -> Vec { + if let Some(n) = n { + let exponentials = Array1::random(n - 1, Exp::new(lambda).unwrap()); + let mut poisson = Array1::::zeros(n); + + for i in 1..n { + poisson[i] = poisson[i - 1] + exponentials[i - 1]; + } - let t_max = t_max.unwrap_or(1.0); - let mut times = vec![0.0]; - let exp = Exp::new(lambda as f64).unwrap(); + poisson.to_vec() + } else if let Some(t_max) = t_max { + let mut poisson = Array1::from(vec![0.0]); + let mut t = 0.0; - while times.last().unwrap() < &t_max { - let inter_arrival = exp.sample(&mut thread_rng()); - let next_time = times.last().unwrap() + inter_arrival; - if next_time > t_max { - break; + while &t < &t_max { + t += Exp::new(lambda).unwrap().sample(&mut thread_rng()); + poisson + .push(Axis(0), Array0::from_elem(Dim(()), t).view()) + .unwrap(); } - times.push(next_time); - } - times + poisson.to_vec() + } else { + panic!("n or t_max must be provided"); + } } pub fn compound_poisson( @@ -66,9 +72,11 @@ mod tests { #[test] fn test_poisson() { let n = 1000; - let lambda = 10; - let t = 10.0; - let p = poisson(n, lambda, Some(t)); + let lambda = 1; + let p = poisson(lambda as f64, Some(n), None); + println!("{:?}", p.len()); + let t = 100.0; + let p = poisson(lambda as f64, None, Some(t)); println!("{:?}", p.len()); }