Skip to content

Commit

Permalink
♻️ More refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeta611 committed Jun 7, 2024
1 parent 3a512cb commit a46f49e
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 102 deletions.
18 changes: 9 additions & 9 deletions lib/compiler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ open! Core
open Parse_tree
open Typed_tree

type env = any_det Id.Map.t
type env = some_det Id.Map.t

let gen_vertex =
let cnt = ref 0 in
Expand Down Expand Up @@ -78,7 +78,7 @@ let rec compile :
match exp with
| Value v -> (Graph.empty, { ty; exp = Value v })
| Var x -> (
let (Any { ty = tx; exp }) = Map.find_exn env x in
let (Ex { ty = tx; exp }) = Map.find_exn env x in
match (tx, ty) with
| Tyi, Tyi -> (Graph.empty, { ty; exp })
| Tyr, Tyr -> (Graph.empty, { ty; exp })
Expand Down Expand Up @@ -118,7 +118,7 @@ let rec compile :
| Let (x, e, body) ->
let g1, det_exp1 = compile env pred e in
let g2, det_exp2 =
compile (Map.set env ~key:x ~data:(Any det_exp1)) pred body
compile (Map.set env ~key:x ~data:(Ex det_exp1)) pred body
in
Graph.(g1 @| g2, det_exp2)
| Call (f, args) ->
Expand All @@ -128,7 +128,7 @@ let rec compile :
let g, de = compile env pred e in
let v = gen_vertex () in
let de_fvs = fv de.exp in
let f : any_det = Any (score de v) in
let f : some_det = Ex (score de v) in
let g' =
Graph.
{
Expand Down Expand Up @@ -180,8 +180,8 @@ let rec compile :
{
vertices = [ v ];
arcs = List.map (Set.to_list fvs) ~f:(fun z -> (z, v));
pmdf_map = Id.Map.singleton v (Any f : any_det);
obs_map = Id.Map.singleton v (Any de2 : any_det);
pmdf_map = Id.Map.singleton v (Ex f : some_det);
obs_map = Id.Map.singleton v (Ex de2 : some_det);
}
in
Graph.(g1 @| g2 @| g', de2)
Expand All @@ -197,8 +197,8 @@ and compile_args :
let g', args = compile_args env pred args in
Graph.(g @| g', arg :: args)

let compile_program (prog : program) : Graph.t * any_det =
let compile_program (prog : program) : Graph.t * some_det =
let open Typing in
let (Any e) = convert Id.Map.empty (inline prog) in
let (Ex e) = convert Id.Map.empty (inline prog) in
let g, e = compile Id.Map.empty { ty = Tyb; exp = Value true } e in
(g, Any e)
(g, Ex e)
92 changes: 48 additions & 44 deletions lib/evaluator.ml
Original file line number Diff line number Diff line change
@@ -1,48 +1,54 @@
open! Core
open Typed_tree

type env = any_v Id.Table.t
module Ctx = struct
type t = some_v Id.Table.t

let rec eval : type a. env -> (a, det) texp -> a =
fun env { ty; exp } ->
let create = Id.Table.create
let set ctx ~name ~value = Hashtbl.set ctx ~key:name ~data:value
let find_exn = Hashtbl.find_exn
end

let rec eval : type a. Ctx.t -> (a, det) texp -> a =
fun ctx { ty; exp } ->
match exp with
| Value v -> v
| Var x -> (
let (Any (tv, v)) = Hashtbl.find_exn env x in
let (Ex (tv, v)) = Ctx.find_exn ctx x in
match (ty, tv) with
| Tyi, Tyi -> v
| Tyr, Tyr -> v
| Tyb, Tyb -> v
| _ -> assert false)
| Bop (op, te1, te2) -> op.f (eval env te1) (eval env te2)
| Uop (op, te) -> op.f (eval env te)
| Bop (op, te1, te2) -> op.f (eval ctx te1) (eval ctx te2)
| Uop (op, te) -> op.f (eval ctx te)
| If (te_pred, te_cons, te_alt) ->
if eval env te_pred then eval env te_cons else eval env te_alt
| Call (f, args) -> f.sampler (eval_args env args)
if eval ctx te_pred then eval ctx te_cons else eval ctx te_alt
| Call (f, args) -> f.sampler (eval_args ctx args)

and eval_args : type a. env -> (a, det) args -> a vargs =
fun env -> function
and eval_args : type a. Ctx.t -> (a, det) args -> a vargs =
fun ctx -> function
| [] -> []
| te :: tl -> (te.ty, eval env te) :: eval_args env tl
| te :: tl -> (te.ty, eval ctx te) :: eval_args ctx tl

let rec eval_pmdf (env : env) (Any { ty; exp } : any_det) :
(any_v -> float) * any_v =
let rec eval_pmdf (ctx : Ctx.t) (Ex { ty; exp } : some_det) :
(some_v -> float) * some_v =
match exp with
| If (te_pred, te_cons, te_alt) ->
if eval env te_pred then eval_pmdf env (Any te_cons)
else eval_pmdf env (Any te_alt)
if eval ctx te_pred then eval_pmdf ctx (Ex te_cons)
else eval_pmdf ctx (Ex te_alt)
| Call (f, args) ->
let pmdf (Any (ty', v) : any_v) =
let pmdf (Ex (ty', v) : some_v) =
match (ty, ty') with
| Tyi, Tyi -> f.log_pmdf (eval_args env args) v
| Tyr, Tyr -> f.log_pmdf (eval_args env args) v
| Tyb, Tyb -> f.log_pmdf (eval_args env args) v
| Tyi, Tyi -> f.log_pmdf (eval_args ctx args) v
| Tyr, Tyr -> f.log_pmdf (eval_args ctx args) v
| Tyb, Tyb -> f.log_pmdf (eval_args ctx args) v
| _, _ -> assert false
in
(pmdf, Any (ty, eval env { ty; exp }))
(pmdf, Ex (ty, eval ctx { ty; exp }))
| _ -> (* not reachable *) assert false

let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (query : any_det) :
let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (query : some_det) :
float array =
(* Initialize the context with the observed values. Float conversion must
succeed as observed variables do not contain free variables *)
Expand All @@ -52,57 +58,55 @@ let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (query : any_det) :
| Tyb -> false
in
let ctx = Id.Table.create () in
let () =
Map.iteri graph.obs_map ~f:(fun ~key ~data:(Any { ty; exp }) ->
let data : any_v = Any (ty, eval ctx { ty; exp }) in
Hashtbl.set ctx ~key ~data)
in
Map.iteri graph.obs_map ~f:(fun ~key:name ~data:(Ex { ty; exp }) ->
let value : some_v = Ex (ty, eval ctx { ty; exp }) in
Ctx.set ctx ~name ~value);

let unobserved = Graph.unobserved_vertices_pmdfs graph in
let () =
List.iter unobserved ~f:(fun (key, Any { ty; _ }) ->
let data : any_v = Any (ty, default ty) in
Hashtbl.set ctx ~key ~data)
in
List.iter unobserved ~f:(fun (name, Ex { ty; _ }) ->
let value : some_v = Ex (ty, default ty) in
Ctx.set ctx ~name ~value);

(* Adapted from gibbs_sampling of Owl *)
let a, b = (1000, 10) in
let num_iter = a + (b * num_samples) in
let samples = Array.create ~len:num_samples 0. in
for i = 0 to num_iter - 1 do
(* Gibbs step *)
List.iter unobserved ~f:(fun (key, exp) ->
let curr = Hashtbl.find_exn ctx key in
List.iter unobserved ~f:(fun (name, exp) ->
let curr = Ctx.find_exn ctx name in
let log_pmdf, cand = eval_pmdf ctx exp in

(* metropolis-hastings update logic *)
Hashtbl.set ctx ~key ~data:cand;
Ctx.set ctx ~name ~value:cand;
let log_pmdf', _ = eval_pmdf ctx exp in
let log_alpha = log_pmdf' curr -. log_pmdf cand in

(* variables influenced by "name" *)
let name_infl =
Map.filteri graph.pmdf_map ~f:(fun ~key:name ~data:(Any { exp; _ }) ->
Id.(name = key) || Set.mem (fv exp) key)
Map.filteri graph.pmdf_map
~f:(fun ~key:name' ~data:(Ex { exp; _ }) ->
Id.(name' = name) || Set.mem (fv exp) name)
in
let log_alpha =
Map.fold name_infl ~init:log_alpha ~f:(fun ~key:name ~data:exp acc ->
Map.fold name_infl ~init:log_alpha ~f:(fun ~key:name' ~data:exp acc ->
let prob_w_cand =
(fst (eval_pmdf ctx exp)) (Hashtbl.find_exn ctx name)
(fst (eval_pmdf ctx exp)) (Ctx.find_exn ctx name')
in
Hashtbl.set ctx ~key ~data:curr;
Ctx.set ctx ~name ~value:curr;
let prob_wo_cand =
(fst (eval_pmdf ctx exp)) (Hashtbl.find_exn ctx name)
(fst (eval_pmdf ctx exp)) (Ctx.find_exn ctx name')
in
Hashtbl.set ctx ~key ~data:cand;
Ctx.set ctx ~name ~value:cand;
acc +. prob_w_cand -. prob_wo_cand)
in

let alpha = Float.exp log_alpha in
let uniform = Owl.Stats.std_uniform_rvs () in
if Float.(uniform > alpha) then Hashtbl.set ctx ~key ~data:curr);
if Float.(uniform > alpha) then Ctx.set ctx ~name ~value:curr);

if i >= a && i mod b = 0 then
let (Any query) = query in
let (Ex query) = query in
let query =
match (query.ty, eval ctx query) with
| Tyi, i -> float_of_int i
Expand All @@ -115,7 +119,7 @@ let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (query : any_det) :
samples

let infer ?(filename : string = "out") ?(num_samples : int = 100_000)
(graph : Graph.t) (query : any_det) : string =
(graph : Graph.t) (query : some_det) : string =
let samples = gibbs_sampling graph ~num_samples query in

let filename = String.chop_suffix_if_exists filename ~suffix:".stp" in
Expand Down
6 changes: 3 additions & 3 deletions lib/graph.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ open Typed_tree

type vertex = Id.t
type arc = vertex * vertex
type pmdf_map = any_det Id.Map.t
type obs_map = any_det Id.Map.t
type pmdf_map = some_det Id.Map.t
type obs_map = some_det Id.Map.t

type t = {
vertices : vertex list;
Expand Down Expand Up @@ -36,7 +36,7 @@ let union g1 g2 =
let ( @| ) = union

let unobserved_vertices_pmdfs ({ vertices; pmdf_map; obs_map; _ } : t) :
(vertex * any_det) list =
(vertex * some_det) list =
List.filter_map vertices ~f:(fun v ->
if Map.mem obs_map v then None
else
Expand Down
6 changes: 3 additions & 3 deletions lib/printing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ let of_graph ({ vertices; arcs; pmdf_map; obs_map } : Graph.t) : graph =
{
vertices;
arcs;
pmdf_map = Map.map pmdf_map ~f:(fun (Any e) -> of_exp e);
obs_map = Map.map obs_map ~f:(fun (Any e) -> of_exp e);
pmdf_map = Map.map pmdf_map ~f:(fun (Ex e) -> of_exp e);
obs_map = Map.map obs_map ~f:(fun (Ex e) -> of_exp e);
}

let to_string (Any e : any_det) = e |> of_exp |> sexp_of_t |> Sexp.to_string_hum
let to_string (Ex e : some_det) = e |> of_exp |> sexp_of_t |> Sexp.to_string_hum
13 changes: 6 additions & 7 deletions lib/typed_tree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,9 @@ and fv_args : type a. (a, det) args -> Id.Set.t = function
| [] -> Id.Set.empty
| { exp; _ } :: es -> Id.(fv exp @| fv_args es)

type any_ndet = Any : (_, non_det) texp -> any_ndet
type any_det = Any : (_, det) texp -> any_det
type any_ty = Any : _ ty -> any_ty
type any_params = Any : _ params -> any_params
type any_v = Any : ('a ty * 'a) -> any_v
type any_dist = Any : _ dist -> any_dist
type tyenv = any_ty Id.Map.t
type some_ndet = Ex : (_, non_det) texp -> some_ndet
type some_det = Ex : (_, det) texp -> some_det
type some_ty = Ex : _ ty -> some_ty
type some_params = Ex : _ params -> some_params
type some_v = Ex : ('a ty * 'a) -> some_v
type some_dist = Ex : _ dist -> some_dist
Loading

0 comments on commit a46f49e

Please sign in to comment.