diff --git a/dune-project b/dune-project index e7b51c4..586b673 100644 --- a/dune-project +++ b/dune-project @@ -29,7 +29,7 @@ (fmt (>= 0.9.0)) (logs (>= 0.7.0)) (owl (>= 1.1)) - (owl-plplot (>= 1.0)) + (pyml (>= 20231101)) (string_dict (>= 0.16.0)) (ppx_jane (>= 0.16.0)) (menhir (>= 20231231)))) diff --git a/lib/dune b/lib/dune index bf16f5d..70f649c 100644 --- a/lib/dune +++ b/lib/dune @@ -1,6 +1,6 @@ (library (name stappl) - (libraries core owl owl-plplot string_dict logs) + (libraries core owl pyml string_dict logs) (inline_tests) (preprocess (pps ppx_jane))) diff --git a/lib/evaluator.ml b/lib/evaluator.ml index 8a6d764..803625c 100644 --- a/lib/evaluator.ml +++ b/lib/evaluator.ml @@ -56,7 +56,7 @@ let rec eval_pmdf : (pmdf, Ex (dty, eval_dist ctx { ty; exp })) let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (Ex query : query) : - float array = + floatarray = (* Initialize the context with the observed values. Float conversion must succeed as observed variables do not contain free variables *) let default : type a. a dty -> a = function @@ -79,7 +79,7 @@ let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (Ex query : query) : (* Adapted from gibbs_sampling of Owl *) let a, b = (1000, 10) in let num_iter = a + (b * num_samples) in - let samples = Array.create ~len:num_samples 0. in + let samples = Stdlib.Float.Array.init num_samples (fun _ -> 0.) in for i = 0 to num_iter - 1 do (* Gibbs step *) List.iter unobserved ~f:(fun (name, Ex exp) -> @@ -122,7 +122,7 @@ let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (Ex query : query) : | Tyi, i -> float_of_int i | Tyr, r -> r in - samples.((i - a) / b) <- query + Stdlib.Float.Array.set samples ((i - a) / b) query done; samples @@ -134,11 +134,8 @@ let infer ?(filename : string = "out") ?(num_samples : int = 100_000) let filename = String.chop_suffix_if_exists filename ~suffix:".stp" in let plot_path = filename ^ ".png" in - let open Owl_plplot in - let h = Plot.create plot_path in - Plot.set_title h - Typed_tree.Erased.([%sexp (of_rv query : exp)] |> Sexp.to_string); - let mat = Owl.Mat.of_array samples 1 num_samples in - Plot.histogram ~h ~bin:50 mat; - Plot.output h; + Plot.draw ~plot_path + ~title:Typed_tree.Erased.([%sexp (of_rv query : exp)] |> Sexp.to_string) + ~samples ~num_samples; + plot_path diff --git a/lib/plot.ml b/lib/plot.ml new file mode 100644 index 0000000..6368a2c --- /dev/null +++ b/lib/plot.ml @@ -0,0 +1,27 @@ +open! Core + +let () = Py.initialize () + +let draw ~plot_path ~title ~samples ~num_samples = + let m = Py.Import.add_module "ocaml" in + List.iter2_exn ~f:(Py.Module.set m) + [ "plot_path"; "title"; "samples"; "num_samples" ] + [ + Py.String.of_string plot_path; + Py.String.of_string title; + Py.Array.numpy samples; + Py.Int.of_int num_samples; + ]; + + Py.Run.eval ~start:Py.File + {| +from ocaml import plot_path, title, samples, num_samples +import seaborn as sns + +sns.set_theme() + +g = sns.displot(samples, element="step", stat="probability", bins=num_samples // 1500) +g.set_titles(title) +g.tight_layout() +g.savefig(plot_path) |} + |> ignore diff --git a/stappl.opam b/stappl.opam index 4f77437..c2b9536 100644 --- a/stappl.opam +++ b/stappl.opam @@ -15,7 +15,7 @@ depends: [ "fmt" {>= "0.9.0"} "logs" {>= "0.7.0"} "owl" {>= "1.1"} - "owl-plplot" {>= "1.0"} + "pyml" {>= "20231101"} "string_dict" {>= "0.16.0"} "ppx_jane" {>= "0.16.0"} "menhir" {>= "20231231"} @@ -38,5 +38,5 @@ build: [ dev-repo: "git+https://github.com/shapespeare/stappl.git" pin-depends: [ [ "owl.1.1" "git+https://github.com/owlbarn/owl#06943b0267e7e80dd0eba94ebf63ca4d25c71910" ] - [ "owl-plplot.1.0" "git+https://github.com/owlbarn/owl-plplot#ebc73c09a907c1c6ca2c4b970bdb70202ec90b50" ] + [ "pyml.20231101" "git+https://github.com/Zeta611/pyml#d62a7b9c2e3a856121c9cc850d71a11b00243b0c" ] ] diff --git a/stappl.opam.template b/stappl.opam.template index ac93838..1d1f037 100644 --- a/stappl.opam.template +++ b/stappl.opam.template @@ -1,4 +1,4 @@ pin-depends: [ [ "owl.1.1" "git+https://github.com/owlbarn/owl#06943b0267e7e80dd0eba94ebf63ca4d25c71910" ] - [ "owl-plplot.1.0" "git+https://github.com/owlbarn/owl-plplot#ebc73c09a907c1c6ca2c4b970bdb70202ec90b50" ] + [ "pyml.20231101" "git+https://github.com/Zeta611/pyml#d62a7b9c2e3a856121c9cc850d71a11b00243b0c" ] ]