Skip to content

Commit

Permalink
Address possible deadlock situation
Browse files Browse the repository at this point in the history
Signed-off-by: Marco Lampacrescia <[email protected]>
  • Loading branch information
MarcoLm993 committed Aug 9, 2024
1 parent ba85603 commit db115e5
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 9 deletions.
7 changes: 7 additions & 0 deletions include/samples/batch_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ class BatchBuffer {
BatchBuffer(const size_t n_threads, const size_t n_slots);
~BatchBuffer() = default;

/*!
* @brief Cancel the buffer, unblocking all threads and preventing new samples from being added
*/
void cancel();

/*!
* @brief Add results to the buffer. Will throw if the buffer is full
* @param results The results to add
Expand All @@ -64,6 +69,8 @@ class BatchBuffer {

private:
const size_t _n_slots;
// Whether the buffer is still active or it was canceled
bool _ok;
// TODO: Consider to use a ring-buffer instead of a deque or vector of Batch results
std::vector<std::deque<BatchResults>> _results_buffer;
mutable std::mutex _buffer_mutex;
Expand Down
7 changes: 7 additions & 0 deletions include/samples/sampling_results.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ class SamplingResults {
*/
void processBatchResults(const BatchResults& res);

/*!
* @brief Check if we can continue the sampling process (not yet converged to a result and valid samples). Update the internal variable
*/
void updateSamplingStatus();

/*!
* @brief Check if we have reached the minimum n. of iterations
* @return Whether we have reached the minimum n. of iterations
Expand Down Expand Up @@ -161,6 +166,8 @@ class SamplingResults {
mutable std::mutex _mtx;
// The kind of property we need to evaluate
const state_properties::PropertyType _property_type;
// Whether we should continue sampling or not (since convergence is reached or other conditions fired)
bool _keep_sampling;
// Variables to keep track of the sampled traces results (Used for P properties)
size_t _n_verified;
size_t _n_not_verified;
Expand Down
19 changes: 16 additions & 3 deletions src/samples/batch_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,32 @@ namespace smc_storm::samples {
BatchBuffer::BatchBuffer(const size_t n_threads, const size_t n_slots) : _n_slots{n_slots}, _results_buffer(n_threads) {
STORM_LOG_THROW(n_threads > 0U, storm::exceptions::IllegalArgumentException, "Number of threads must be positive");
STORM_LOG_THROW(n_slots > 0U, storm::exceptions::IllegalArgumentException, "Number of slots must be positive");
_ok = true;
}

void BatchBuffer::cancel() {
std::lock_guard<std::mutex> lock(_buffer_mutex);
_ok = false;
std::for_each(_results_buffer.begin(), _results_buffer.end(), [](auto& thread_results) { thread_results.clear(); });
_buffer_cv.notify_all();
}

void BatchBuffer::addResults(const BatchResults& results, const size_t thread_id) {
STORM_LOG_THROW(thread_id < _results_buffer.size(), storm::exceptions::IllegalArgumentException, "Thread id out of bounds");
std::lock_guard<std::mutex> lock(_buffer_mutex);
STORM_LOG_THROW(_results_buffer[thread_id].size() < _n_slots, storm::exceptions::IllegalArgumentException, "Buffer is full");
_results_buffer.at(thread_id).emplace_back(results);
if (_ok) {
STORM_LOG_THROW(_results_buffer[thread_id].size() < _n_slots, storm::exceptions::IllegalArgumentException, "Buffer is full");
_results_buffer.at(thread_id).emplace_back(results);
}
}

std::optional<std::vector<BatchResults>> BatchBuffer::getResults() {
std::vector<BatchResults> results;
{
std::lock_guard<std::mutex> lock(_buffer_mutex);
if (!_ok) {
return std::nullopt;
}
if (std::any_of(
_results_buffer.begin(), _results_buffer.end(), [](const auto& thread_results) { return thread_results.empty(); })) {
return std::nullopt;
Expand All @@ -55,7 +68,7 @@ std::optional<std::vector<BatchResults>> BatchBuffer::getResults() {
void BatchBuffer::waitForSlotAvailable(const size_t thread_id) const {
STORM_LOG_THROW(thread_id < _results_buffer.size(), storm::exceptions::IllegalArgumentException, "Thread id out of bounds");
std::unique_lock<std::mutex> lock(_buffer_mutex);
_buffer_cv.wait(lock, [this, thread_id] { return _results_buffer[thread_id].size() < _n_slots; });
_buffer_cv.wait(lock, [this, thread_id] { return !_ok || _results_buffer[thread_id].size() < _n_slots; });
}

} // namespace smc_storm::samples
26 changes: 20 additions & 6 deletions src/samples/sampling_results.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ SamplingResults::SamplingResults(const settings::SmcSettings& settings, const st
: _settings{settings},
_results_buffer(settings.n_threads, 6U), _property_type{prop}, _quantile{calculateQuantile(_settings.confidence)}, _min_iterations{
50U} {
_keep_sampling = true;
_n_verified = 0U;
_n_not_verified = 0U;
_n_no_info = 0U;
Expand Down Expand Up @@ -200,6 +201,10 @@ void SamplingResults::addBatchResults(const BatchResults& res, const size_t thre
for (const auto& thread_res : *all_threads_res) {
processBatchResults(thread_res);
}
updateSamplingStatus();
if (!_keep_sampling) {
_results_buffer.cancel();
}
}
}

Expand Down Expand Up @@ -257,26 +262,35 @@ double SamplingResults::getProbabilityVerifiedProperty() const {

bool SamplingResults::newBatchNeeded(const size_t thread_id) const {
// Check if the buffer for the thread is full. Wait for a slot to be available in case
_results_buffer.waitForSlotAvailable(thread_id);
std::scoped_lock<std::mutex> lock(_mtx);
if (_keep_sampling) {
_results_buffer.waitForSlotAvailable(thread_id);
}
// The result might be available in the meanwhile, so use the _keep_sampling as return
return _keep_sampling;
}

void SamplingResults::updateSamplingStatus() {
// Reward properties require always reaching the target state
if (_property_type == state_properties::PropertyType::R && (_n_no_info > 0U || _n_not_verified > 0U)) {
return false;
_keep_sampling = false;
return;
}
const size_t n_samples = _n_no_info + _n_not_verified + _n_verified;
if (_settings.max_n_traces > 0U && n_samples >= _settings.max_n_traces) {
return false;
_keep_sampling = false;
return;
}
if (!minIterationsReached()) {
return true;
_keep_sampling = true;
return;
}
// Check if we never reached a terminal states
if (n_samples > _min_iterations && _n_no_info > n_samples * 0.5) {
STORM_LOG_THROW(
false, storm::exceptions::UnexpectedException,
"More than half the generated traces do not reach the terminal state. Aborting.");
}
return _bound_function();
_keep_sampling = _bound_function();
}

double SamplingResults::calculateQuantile(const double& confidence) {
Expand Down

0 comments on commit db115e5

Please sign in to comment.