Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extended features for rsdd-ocaml #7

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ crate-type = ["staticlib", "cdylib"]

[dependencies]
ocaml = { version = "^1.0.0-beta" }
rsdd = { git = "https://github.com/neuppl/rsdd", rev = "1613459" }
rsdd = { git = "https://github.com/minsungc/rsdd-dappl", rev = "e6ede39" }

[build-dependencies]
ocaml-build = {version = "^1.0.0-beta"}
74 changes: 47 additions & 27 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashMap;

use rsdd::{
builder::{bdd::RobddBuilder, cache::AllIteTable, BottomUpBuilder},
builder::{bdd::{RobddBuilder, BddBuilder}, cache::AllIteTable, BottomUpBuilder},
constants::primes,
repr::{BddPtr, Cnf, DDNNFPtr, PartialModel, VarLabel, VarOrder, WmcParams},
util::semirings::{ExpectedUtility, FiniteField, RealSemiring, Semiring},
Expand Down Expand Up @@ -40,6 +40,7 @@ unsafe impl ocaml::FromValue for RsddVarLabel {
}
}


// disc/dice interface

#[ocaml::func]
Expand All @@ -61,6 +62,24 @@ pub fn bdd_new_var(
(lbl.value(), RsddBddPtr(ptr).into())
}

#[ocaml::func]
#[ocaml::sig("int64 -> rsdd_var_label")]
pub fn mk_varlabel(
i : i64
) -> RsddVarLabel {
RsddVarLabel(VarLabel::new(i as u64))
}

#[ocaml::func]
#[ocaml::sig("rsdd_bdd_builder -> rsdd_var_label -> bool -> rsdd_bdd_ptr")]
pub fn bdd_var(
builder: &'static RsddBddBuilder,
lbl : RsddVarLabel,
polarity : bool,
) -> ocaml::Pointer<RsddBddPtr> {
RsddBddPtr(builder.0.var(lbl.0, polarity)).into()
}

#[ocaml::func]
#[ocaml::sig("rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr")]
pub fn bdd_ite(
Expand Down Expand Up @@ -101,6 +120,16 @@ pub fn bdd_negate(
RsddBddPtr(builder.0.negate(bdd.0)).into()
}

#[ocaml::func]
#[ocaml::sig("rsdd_bdd_builder -> int64 list -> rsdd_bdd_ptr")]
pub fn bdd_exactlyone(
builder: &'static RsddBddBuilder,
l : ocaml::List<i64>,
) -> ocaml::Pointer<RsddBddPtr> {
let l_of_varlabels : Vec<_> = l.into_vec().iter().map(|x| VarLabel::new_usize(*x as usize)).collect();
RsddBddPtr(builder.0.exactly_one_of_varlabels(&l_of_varlabels)).into()
}

#[ocaml::func]
#[ocaml::sig("rsdd_bdd_builder -> rsdd_bdd_ptr")]
pub fn bdd_true(builder: &'static RsddBddBuilder) -> ocaml::Pointer<RsddBddPtr> {
Expand Down Expand Up @@ -190,57 +219,48 @@ pub fn new_wmc_params_r(weights: ocaml::List<(f64, f64)>) -> ocaml::Pointer<Rsdd

// branch & bound, expected semiring items
#[ocaml::sig]
#[derive(ocaml::ToValue, ocaml::FromValue)]
pub struct RsddExpectedUtility(ExpectedUtility);
ocaml::custom!(RsddExpectedUtility);

#[ocaml::sig]
pub struct RsddWmcParamsEU(WmcParams<ExpectedUtility>);
ocaml::custom!(RsddWmcParamsEU);


#[ocaml::func]
#[ocaml::sig("rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model")]
pub fn bdd_bb(
bdd: &'static RsddBddPtr,
join_vars: ocaml::List<RsddVarLabel>,
num_vars: u64,
wmc: &RsddWmcParamsEU,
) -> (
ocaml::Pointer<RsddExpectedUtility>,
ocaml::Pointer<RsddPartialModel>,
) {
let (eu, pm) = bdd.0.bb(
&join_vars
.into_linked_list()
.iter()
.map(|x| x.0)
.collect::<Vec<_>>(),
num_vars as usize,
&wmc.0,
);
(RsddExpectedUtility(eu).into(), RsddPartialModel(pm).into())
#[ocaml::sig("rsdd_expected_utility -> float * float")]
pub fn extract(
eu : RsddExpectedUtility
) -> (f64, f64) {
let v = eu.0 ;
(v.0, v.1)
}


#[ocaml::func]
#[ocaml::sig("rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model")]
#[ocaml::sig("rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model")]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to spin up a changelog for the package in general, but to call out - if you change the API for any existing functions, we should def mention in a PR and make this the proper semver bump!

pub fn bdd_meu(
bdd: &'static RsddBddPtr,
decision_vars: ocaml::List<RsddVarLabel>,
evidence: &'static RsddBddPtr,
join_vars: ocaml::List<RsddVarLabel>,
num_vars: u64,
wmc: &RsddWmcParamsEU,
) -> (
ocaml::Pointer<RsddExpectedUtility>,
RsddExpectedUtility,
ocaml::Pointer<RsddPartialModel>,
) {
let (eu, pm) = bdd.0.bb(
&decision_vars
let (eu, pm) = bdd.0.meu(
evidence.0,
&join_vars
.into_linked_list()
.iter()
.map(|x| x.0)
.collect::<Vec<_>>(),
num_vars as usize,
&wmc.0,
);
(RsddExpectedUtility(eu).into(), RsddPartialModel(pm).into())
(RsddExpectedUtility(eu), RsddPartialModel(pm).into())
}

#[ocaml::func]
Expand Down
91 changes: 26 additions & 65 deletions src/rsdd.ml
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might need to run dune fmt to make the diff cleaner (i.e. dune build && dune fmt); I will also add this to our internal dev docs.

Original file line number Diff line number Diff line change
Expand Up @@ -12,68 +12,29 @@ type rsdd_var_label
type rsdd_wmc_params_r
type rsdd_expected_utility
type rsdd_wmc_params_e_u

external mk_bdd_builder_default_order : int64 -> rsdd_bdd_builder
= "mk_bdd_builder_default_order"

external bdd_new_var : rsdd_bdd_builder -> bool -> int64 * rsdd_bdd_ptr
= "bdd_new_var"

external bdd_ite :
rsdd_bdd_builder ->
rsdd_bdd_ptr ->
rsdd_bdd_ptr ->
rsdd_bdd_ptr ->
rsdd_bdd_ptr = "bdd_ite"

external bdd_and :
rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_and"

external bdd_or :
rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_or"

external bdd_negate : rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr
= "bdd_negate"

external bdd_true : rsdd_bdd_builder -> rsdd_bdd_ptr = "bdd_true"
external bdd_false : rsdd_bdd_builder -> rsdd_bdd_ptr = "bdd_false"
external bdd_is_true : rsdd_bdd_ptr -> bool = "bdd_is_true"
external bdd_is_false : rsdd_bdd_ptr -> bool = "bdd_is_false"
external bdd_is_const : rsdd_bdd_ptr -> bool = "bdd_is_const"

external bdd_eq : rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> bool
= "bdd_eq"

external bdd_topvar : rsdd_bdd_ptr -> int64 = "bdd_topvar"
external bdd_low : rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_low"
external bdd_high : rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_high"
external bdd_wmc : rsdd_bdd_ptr -> rsdd_wmc_params_r -> float = "bdd_wmc"

external new_wmc_params_r : (float * float) list -> rsdd_wmc_params_r
= "new_wmc_params_r"

external bdd_bb :
rsdd_bdd_ptr ->
rsdd_var_label list ->
int64 ->
rsdd_wmc_params_e_u ->
rsdd_expected_utility * rsdd_partial_model = "bdd_bb"

external bdd_meu :
rsdd_bdd_ptr ->
rsdd_var_label list ->
int64 ->
rsdd_wmc_params_e_u ->
rsdd_expected_utility * rsdd_partial_model = "bdd_meu"

external new_wmc_params_eu :
((float * float) * (float * float)) list -> rsdd_wmc_params_e_u
= "new_wmc_params_eu"

external cnf_from_dimacs : string -> rsdd_cnf = "cnf_from_dimacs"

external bdd_builder_compile_cnf : rsdd_bdd_builder -> rsdd_cnf -> rsdd_bdd_ptr
= "bdd_builder_compile_cnf"

external bdd_model_count : rsdd_bdd_builder -> rsdd_bdd_ptr -> int64
= "bdd_model_count"
external mk_bdd_builder_default_order: int64 -> rsdd_bdd_builder = "mk_bdd_builder_default_order"
external bdd_new_var: rsdd_bdd_builder -> bool -> (int64 * rsdd_bdd_ptr) = "bdd_new_var"
external mk_varlabel: int64 -> rsdd_var_label = "mk_varlabel"
external bdd_var: rsdd_bdd_builder -> rsdd_var_label -> bool -> rsdd_bdd_ptr = "bdd_var"
external bdd_ite: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_ite"
external bdd_and: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_and"
external bdd_or: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_or"
external bdd_negate: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_negate"
external bdd_exactlyone: rsdd_bdd_builder -> int64 list -> rsdd_bdd_ptr = "bdd_exactlyone"
external bdd_true: rsdd_bdd_builder -> rsdd_bdd_ptr = "bdd_true"
external bdd_false: rsdd_bdd_builder -> rsdd_bdd_ptr = "bdd_false"
external bdd_is_true: rsdd_bdd_ptr -> bool = "bdd_is_true"
external bdd_is_false: rsdd_bdd_ptr -> bool = "bdd_is_false"
external bdd_is_const: rsdd_bdd_ptr -> bool = "bdd_is_const"
external bdd_eq: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> bool = "bdd_eq"
external bdd_topvar: rsdd_bdd_ptr -> int64 = "bdd_topvar"
external bdd_low: rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_low"
external bdd_high: rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_high"
external bdd_wmc: rsdd_bdd_ptr -> rsdd_wmc_params_r -> float = "bdd_wmc"
external new_wmc_params_r: (float * float) list -> rsdd_wmc_params_r = "new_wmc_params_r"
external extract: rsdd_expected_utility -> float * float = "extract"
external bdd_meu: rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model = "bdd_meu"
external new_wmc_params_eu: ((float * float) * (float * float)) list -> rsdd_wmc_params_e_u = "new_wmc_params_eu"
external cnf_from_dimacs: string -> rsdd_cnf = "cnf_from_dimacs"
external bdd_builder_compile_cnf: rsdd_bdd_builder -> rsdd_cnf -> rsdd_bdd_ptr = "bdd_builder_compile_cnf"
external bdd_model_count: rsdd_bdd_builder -> rsdd_bdd_ptr -> int64 = "bdd_model_count"
91 changes: 26 additions & 65 deletions src/rsdd.mli
Original file line number Diff line number Diff line change
Expand Up @@ -12,68 +12,29 @@ type rsdd_var_label
type rsdd_wmc_params_r
type rsdd_expected_utility
type rsdd_wmc_params_e_u

external mk_bdd_builder_default_order : int64 -> rsdd_bdd_builder
= "mk_bdd_builder_default_order"

external bdd_new_var : rsdd_bdd_builder -> bool -> int64 * rsdd_bdd_ptr
= "bdd_new_var"

external bdd_ite :
rsdd_bdd_builder ->
rsdd_bdd_ptr ->
rsdd_bdd_ptr ->
rsdd_bdd_ptr ->
rsdd_bdd_ptr = "bdd_ite"

external bdd_and :
rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_and"

external bdd_or :
rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_or"

external bdd_negate : rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr
= "bdd_negate"

external bdd_true : rsdd_bdd_builder -> rsdd_bdd_ptr = "bdd_true"
external bdd_false : rsdd_bdd_builder -> rsdd_bdd_ptr = "bdd_false"
external bdd_is_true : rsdd_bdd_ptr -> bool = "bdd_is_true"
external bdd_is_false : rsdd_bdd_ptr -> bool = "bdd_is_false"
external bdd_is_const : rsdd_bdd_ptr -> bool = "bdd_is_const"

external bdd_eq : rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> bool
= "bdd_eq"

external bdd_topvar : rsdd_bdd_ptr -> int64 = "bdd_topvar"
external bdd_low : rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_low"
external bdd_high : rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_high"
external bdd_wmc : rsdd_bdd_ptr -> rsdd_wmc_params_r -> float = "bdd_wmc"

external new_wmc_params_r : (float * float) list -> rsdd_wmc_params_r
= "new_wmc_params_r"

external bdd_bb :
rsdd_bdd_ptr ->
rsdd_var_label list ->
int64 ->
rsdd_wmc_params_e_u ->
rsdd_expected_utility * rsdd_partial_model = "bdd_bb"

external bdd_meu :
rsdd_bdd_ptr ->
rsdd_var_label list ->
int64 ->
rsdd_wmc_params_e_u ->
rsdd_expected_utility * rsdd_partial_model = "bdd_meu"

external new_wmc_params_eu :
((float * float) * (float * float)) list -> rsdd_wmc_params_e_u
= "new_wmc_params_eu"

external cnf_from_dimacs : string -> rsdd_cnf = "cnf_from_dimacs"

external bdd_builder_compile_cnf : rsdd_bdd_builder -> rsdd_cnf -> rsdd_bdd_ptr
= "bdd_builder_compile_cnf"

external bdd_model_count : rsdd_bdd_builder -> rsdd_bdd_ptr -> int64
= "bdd_model_count"
external mk_bdd_builder_default_order: int64 -> rsdd_bdd_builder = "mk_bdd_builder_default_order"
external bdd_new_var: rsdd_bdd_builder -> bool -> (int64 * rsdd_bdd_ptr) = "bdd_new_var"
external mk_varlabel: int64 -> rsdd_var_label = "mk_varlabel"
external bdd_var: rsdd_bdd_builder -> rsdd_var_label -> bool -> rsdd_bdd_ptr = "bdd_var"
external bdd_ite: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_ite"
external bdd_and: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_and"
external bdd_or: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_or"
external bdd_negate: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_negate"
external bdd_exactlyone: rsdd_bdd_builder -> int64 list -> rsdd_bdd_ptr = "bdd_exactlyone"
external bdd_true: rsdd_bdd_builder -> rsdd_bdd_ptr = "bdd_true"
external bdd_false: rsdd_bdd_builder -> rsdd_bdd_ptr = "bdd_false"
external bdd_is_true: rsdd_bdd_ptr -> bool = "bdd_is_true"
external bdd_is_false: rsdd_bdd_ptr -> bool = "bdd_is_false"
external bdd_is_const: rsdd_bdd_ptr -> bool = "bdd_is_const"
external bdd_eq: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> bool = "bdd_eq"
external bdd_topvar: rsdd_bdd_ptr -> int64 = "bdd_topvar"
external bdd_low: rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_low"
external bdd_high: rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_high"
external bdd_wmc: rsdd_bdd_ptr -> rsdd_wmc_params_r -> float = "bdd_wmc"
external new_wmc_params_r: (float * float) list -> rsdd_wmc_params_r = "new_wmc_params_r"
external extract: rsdd_expected_utility -> float * float = "extract"
external bdd_meu: rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model = "bdd_meu"
external new_wmc_params_eu: ((float * float) * (float * float)) list -> rsdd_wmc_params_e_u = "new_wmc_params_eu"
external cnf_from_dimacs: string -> rsdd_cnf = "cnf_from_dimacs"
external bdd_builder_compile_cnf: rsdd_bdd_builder -> rsdd_cnf -> rsdd_bdd_ptr = "bdd_builder_compile_cnf"
external bdd_model_count: rsdd_bdd_builder -> rsdd_bdd_ptr -> int64 = "bdd_model_count"
34 changes: 32 additions & 2 deletions test/test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,37 @@ let () =
let robdd_builder = mk_bdd_builder_default_order 6L in
let bdd =
bdd_builder_compile_cnf robdd_builder
(cnf_from_dimacs "\n p cnf 6 3\n1 2 3 4 0\n-2 -3 4 5 0\n-4 -5 6 6 0\n")
in
(cnf_from_dimacs "\n p cnf 6 3\n1 2 3 4 0\n-2 -3 4 5 0\n-4 -5 6 6 0\n") in
let model_count = bdd_model_count robdd_builder bdd in
print_endline (Int64.to_string model_count)

let () =
let builder = mk_bdd_builder_default_order 6L in
let bdd = bdd_exactlyone builder [0L; 1L; 2L ; 3L; 4L; 5L] in
let mc = bdd_model_count builder bdd in
print_endline (Int64.to_string mc)

let () =
let builder = mk_bdd_builder_default_order 0L in
let (_, flip_one_half) = bdd_new_var builder true in
let (lbl1, decision_one) = bdd_new_var builder true in
let (lbl2, decision_two) = bdd_new_var builder true in
let (_, reward_5) = bdd_new_var builder true in
let (_, reward_4) = bdd_new_var builder true in
let (_, reward_0) = bdd_new_var builder true in
let ite = bdd_ite builder flip_one_half reward_5 reward_0 in
let d1 = bdd_and builder decision_one ite in
let d2 = bdd_and builder decision_two reward_4 in
let xor = bdd_and builder (bdd_or builder decision_one decision_two) (bdd_negate builder (bdd_and builder decision_one decision_two)) in
let total = bdd_and builder (bdd_or builder d1 d2) xor in
let sdfsdf = List.map mk_varlabel [lbl1; lbl2] in
let param = new_wmc_params_eu [((0.5, 0.0), (0.5, 0.0));
((1.0, 0.0), (1.0, 0.0));
((1.0, 0.0), (1.0, 0.0));
((1.0, 0.0), (1.0, 5.0));
((1.0, 0.0), (1.0, 4.0));
((1.0, 0.0), (1.0, 0.0));] in
let (asdf, _) = bdd_meu total (bdd_true builder) sdfsdf 3L param in
let (a,b) = (extract asdf) in
let asdfasdfasdf = String.concat " " [(Float.to_string a);(Float.to_string b)] in
print_endline(asdfasdfasdf)
Loading