Skip to content

Commit

Permalink
add bnn to appmc
Browse files Browse the repository at this point in the history
  • Loading branch information
AL-JiongYang committed Jan 25, 2025
1 parent 8590458 commit 1f67a1b
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ if(POLICY CMP0026)
cmake_policy(SET CMP0026 NEW)
endif()

if (ENABLE_BNN)
add_definitions( -DENABLE_BNN )
endif()

# -----------------------------------------------------------------------------
# Provide scripts dir for included cmakes to use
# -----------------------------------------------------------------------------
Expand Down
10 changes: 10 additions & 0 deletions src/approxmc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,16 @@ DLL_PUBLIC bool AppMC::add_xor_clause(const vector<uint32_t>& vars, bool rhs)
return data->counter.solver_add_xor_clause(vars, rhs);
}

#ifdef ENABLE_BNN
DLL_PUBLIC bool AppMC::add_bnn_clause(
const std::vector<Lit>& lits,
signed cutoff,
Lit out )
{
return data->counter.solver_add_bnn_clause(lits, cutoff, out);
}
#endif

DLL_PUBLIC CMSat::SATSolver* AppMC::get_solver()
{
return data->counter.solver;
Expand Down
7 changes: 7 additions & 0 deletions src/approxmc.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ class AppMC
bool add_red_clause(const std::vector<CMSat::Lit>& lits);
bool add_xor_clause(const std::vector<CMSat::Lit>& lits, bool rhs);
bool add_xor_clause(const std::vector<uint32_t>& vars, bool rhs);
#ifdef ENABLE_BNN
bool add_bnn_clause(
const std::vector<CMSat::Lit>& lits,
signed cutoff,
CMSat::Lit out = CMSat::lit_Undef
);
#endif

// Information about approxmc
std::string get_version_info();
Expand Down
22 changes: 21 additions & 1 deletion src/counter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ bool Counter::solver_add_xor_clause(const vector<uint32_t>& vars, const bool rhs
return solver->add_xor_clause(vars, rhs);
}

#ifdef ENABLE_BNN
bool Counter::solver_add_bnn_clause(
const std::vector<Lit>& lits,
signed cutoff,
Lit out )
{
if (conf.dump_intermediary_cnf) bnns_in_solver.push_back(make_pair(lits, make_pair(cutoff, out)));
return solver->add_bnn_clause(lits, cutoff, out);
}
#endif

Hash Counter::add_hash(uint32_t hash_index, SparseData& sparse_data)
{
string random_bits;
Expand Down Expand Up @@ -200,7 +211,7 @@ void Counter::dump_cnf_from_solver(const vector<Lit>& assumps, const uint32_t it

std::ofstream f;
f.open(ss.str(), std::ios::out);
f << "p cnf " << solver->nVars()+1 << " " << cls_in_solver.size()+xors_in_solver.size()+assumps.size() << endl;
f << "p cnf " << solver->nVars()+1 << " " << cls_in_solver.size()+xors_in_solver.size()+bnns_in_solver.size()+assumps.size() << endl;
for(const auto& cl: cls_in_solver) f << cl << " 0" << endl;
f << "c XORs below" << endl;
for(const auto& x: xors_in_solver) {
Expand All @@ -213,6 +224,14 @@ void Counter::dump_cnf_from_solver(const vector<Lit>& assumps, const uint32_t it
}
f << "0" << endl;
}
f << "c BNNs below" << endl;
for(const auto& bnn: bnns_in_solver) {
f << "b ";
for(uint32_t i = 0; i < bnn.first.size(); i++) {
f << bnn.first[i] << " ";
}
f << "0 " << bnn.second.first << " " << bnn.second.second << " 0" << endl;
}
f << "c assumptions below" << endl;
for(const auto& l: assumps) f << l << " 0" << endl;
f.close();
Expand Down Expand Up @@ -903,6 +922,7 @@ void Counter::check_model(
}
assert(sat);
}
// todo: add check bnn
}

if (!hm) return;
Expand Down
8 changes: 8 additions & 0 deletions src/counter.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ class Counter {
bool solver_add_clause(const vector<Lit>& cl);
bool solver_add_xor_clause(const vector<uint32_t>& vars, const bool rhs);
bool solver_add_xor_clause(const vector<Lit>& lits, const bool rhs);
#ifdef ENABLE_BNN
bool solver_add_bnn_clause(
const std::vector<Lit>& lits,
signed cutoff,
Lit out = lit_Undef
);
#endif

private:
Config& conf;
Expand Down Expand Up @@ -193,6 +200,7 @@ class Counter {
uint32_t base_rand = 0;
vector<vector<Lit>> cls_in_solver; // needed for accurate dumping
vector<pair<vector<Lit>, bool>> xors_in_solver; // needed for accurate dumping
vector<pair<vector<Lit>, pair<int32_t, Lit>>> bnns_in_solver; // needed for accurate dumping

int argc;
char** argv;
Expand Down

0 comments on commit 1f67a1b

Please sign in to comment.