Store sequence ID in timestamps tuple object (#367)
* Store sequence ID in timestamps tuple object

* Fix bug and address feedback

* Address feedback

* Address feedback
matthewkotila committed Jul 28, 2023
1 parent 84260d0 commit edc03b3
Showing 18 changed files with 390 additions and 202 deletions.
1 change: 1 addition & 0 deletions src/c++/perf_analyzer/CMakeLists.txt
@@ -104,6 +104,7 @@ set(
concurrency_ctx_id_tracker.h
fifo_ctx_id_tracker.h
rand_ctx_id_tracker.h
+ request_record.h
)

add_executable(
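
The new `request_record.h` added to the build above is one of the files not expanded on this page. Its shape can be inferred from the diffs that follow: the struct replaces the old timestamp tuple and carries `start_time_`, `response_times_`, `sequence_end_`, `delayed_`, and the new `sequence_id_`. A plausible sketch, with the constructor shape assumed and the defaults borrowed from the removed `AsyncRequestProperties`:

```cpp
// Hypothetical reconstruction of request_record.h; field names are taken
// from the diffs below, constructor shape and defaults are assumptions.
#include <chrono>
#include <cstdint>
#include <vector>

namespace triton { namespace perfanalyzer {

/// Records the timing and metadata of a single inference request,
/// replacing the old <start, response times, sequence_end, delayed> tuple.
struct RequestRecord {
  using time_point = std::chrono::time_point<std::chrono::system_clock>;

  RequestRecord() = default;
  RequestRecord(
      time_point start_time, std::vector<time_point> response_times,
      bool sequence_end, bool delayed, uint64_t sequence_id)
      : start_time_(start_time), response_times_(std::move(response_times)),
        sequence_end_(sequence_end), delayed_(delayed),
        sequence_id_(sequence_id)
  {
  }

  // The timestamp of when the request was started.
  time_point start_time_;
  // One entry per response; decoupled models may produce several.
  std::vector<time_point> response_times_;
  // Whether or not the request is at the end of a sequence.
  bool sequence_end_{false};
  // Whether or not the request is delayed as per schedule.
  bool delayed_{true};
  // ID of the sequence the request belongs to; 0 means no sequence.
  uint64_t sequence_id_{0};
};

}}  // namespace triton::perfanalyzer
```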
45 changes: 29 additions & 16 deletions src/c++/perf_analyzer/client_backend/mock_client_backend.h
@@ -103,7 +103,7 @@ class MockInferInput : public InferInput {
///
class MockInferResult : public InferResult {
public:
-  MockInferResult(const InferOptions& options) : req_id_{options.request_id_} {}
+  MockInferResult(const InferOptions& options) : req_id_(options.request_id_) {}

Error Id(std::string* id) const override
{
@@ -468,14 +468,36 @@ class MockClientStats {

/// Mock implementation of ClientBackend interface
///
-class MockClientBackend : public ClientBackend {
+class NaggyMockClientBackend : public ClientBackend {
public:
-  MockClientBackend(std::shared_ptr<MockClientStats> stats) { stats_ = stats; }
+  NaggyMockClientBackend(std::shared_ptr<MockClientStats> stats) : stats_(stats)
+  {
+    ON_CALL(*this, AsyncStreamInfer(testing::_, testing::_, testing::_))
+        .WillByDefault(
+            [this](
+                const InferOptions& options,
+                const std::vector<InferInput*>& inputs,
+                const std::vector<const InferRequestedOutput*>& outputs)
+                -> Error {
+              stats_->CaptureRequest(
+                  MockClientStats::ReqType::ASYNC_STREAM, options, inputs,
+                  outputs);
+
+              LaunchAsyncMockRequest(options, stream_callback_);
+
+              return stats_->GetNextReturnStatus();
+            });
+  }

MOCK_METHOD(
Error, ModelConfig,
(rapidjson::Document*, const std::string&, const std::string&),
(override));
+  MOCK_METHOD(
+      Error, AsyncStreamInfer,
+      (const InferOptions&, const std::vector<InferInput*>&,
+       const std::vector<const InferRequestedOutput*>&),
+      (override));

Error Infer(
InferResult** result, const InferOptions& options,
@@ -506,18 +528,6 @@ class MockClientBackend : public ClientBackend {
return stats_->GetNextReturnStatus();
}

-  Error AsyncStreamInfer(
-      const InferOptions& options, const std::vector<InferInput*>& inputs,
-      const std::vector<const InferRequestedOutput*>& outputs)
-  {
-    stats_->CaptureRequest(
-        MockClientStats::ReqType::ASYNC_STREAM, options, inputs, outputs);
-
-    LaunchAsyncMockRequest(options, stream_callback_);
-
-    return stats_->GetNextReturnStatus();
-  }
-
Error StartStream(OnCompleteFn callback, bool enable_stats)
{
stats_->CaptureStreamStart();
@@ -601,6 +611,8 @@
return Error::Success;
}

+  OnCompleteFn stream_callback_;
+
private:
void LaunchAsyncMockRequest(const InferOptions options, OnCompleteFn callback)
{
@@ -619,9 +631,10 @@
size_t local_completed_req_count_ = 0;

std::shared_ptr<MockClientStats> stats_;
-  OnCompleteFn stream_callback_;
};

+using MockClientBackend = testing::NiceMock<NaggyMockClientBackend>;
+
/// Mock factory that always creates a MockClientBackend instead
/// of a real backend
///
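
The rename to `NaggyMockClientBackend` plus the `NiceMock` alias at the end of this file is a standard GoogleMock pattern: the constructor installs a default action with `ON_CALL` so the mock keeps its capture-and-launch behavior, and wrapping the class in `testing::NiceMock` silences "uninteresting mock function call" warnings in tests that set no expectations. A minimal self-contained illustration of the same pattern (the `Backend` and `Ping` names are invented for this example):

```cpp
#include "gmock/gmock.h"
#include "gtest/gtest.h"

// Toy interface standing in for ClientBackend.
class Backend {
 public:
  virtual ~Backend() = default;
  virtual int Ping(int x) = 0;
};

// "Naggy" mock: warns on uninteresting calls, but a default action
// installed in the constructor still provides working behavior.
class NaggyMockBackend : public Backend {
 public:
  NaggyMockBackend()
  {
    ON_CALL(*this, Ping(testing::_)).WillByDefault([](int x) {
      return x + 1;  // default action runs even without an EXPECT_CALL
    });
  }
  MOCK_METHOD(int, Ping, (int), (override));
};

// Quiet wrapper most tests use, mirroring
// `using MockClientBackend = testing::NiceMock<NaggyMockClientBackend>;`.
using MockBackend = testing::NiceMock<NaggyMockBackend>;

TEST(MockBackendTest, DefaultActionRuns)
{
  MockBackend backend;
  EXPECT_EQ(backend.Ping(41), 42);  // no warning, default action fires
}
```

Because the default action lives in the constructor rather than in each test, `AsyncStreamInfer` keeps capturing stats and launching mock requests even in tests that never mention it, while tests that care can still attach an `EXPECT_CALL`.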
38 changes: 22 additions & 16 deletions src/c++/perf_analyzer/infer_context.cc
@@ -74,7 +74,9 @@ InferContext::SendSequenceInferRequest(uint32_t seq_stat_index, bool delayed)

sequence_manager_->DecrementRemainingQueries(seq_stat_index);

-    SendRequest(request_id_++, delayed);
+    SendRequest(
+        request_id_++, delayed,
+        sequence_manager_->GetSequenceID(seq_stat_index));
}
}

@@ -95,12 +97,15 @@ InferContext::CompleteOngoingSequence(uint32_t seq_stat_index)
sequence_manager_->DecrementRemainingQueries(seq_stat_index);

bool is_delayed = false;
-    SendRequest(request_id_++, is_delayed);
+    SendRequest(
+        request_id_++, is_delayed,
+        sequence_manager_->GetSequenceID(seq_stat_index));
}
}

void
-InferContext::SendRequest(const uint64_t request_id, const bool delayed)
+InferContext::SendRequest(
+    const uint64_t request_id, const bool delayed, const uint64_t sequence_id)
{
if (!thread_stat_->status_.IsOk()) {
return;
@@ -111,14 +116,13 @@
infer_data_.options_->request_id_ = std::to_string(request_id);
{
std::lock_guard<std::mutex> lock(thread_stat_->mu_);
-    auto it =
-        async_req_map_
-            .emplace(
-                infer_data_.options_->request_id_, AsyncRequestProperties())
-            .first;
+    auto it = async_req_map_
+                  .emplace(infer_data_.options_->request_id_, RequestRecord())
+                  .first;
it->second.start_time_ = std::chrono::system_clock::now();
it->second.sequence_end_ = infer_data_.options_->sequence_end_;
it->second.delayed_ = delayed;
+    it->second.sequence_id_ = sequence_id;
}

thread_stat_->idle_timer.Start();
@@ -157,13 +161,13 @@
std::vector<std::chrono::time_point<std::chrono::system_clock>>
end_time_syncs{end_time_sync};
{
-      // Add the request timestamp to thread Timestamp vector with proper
+      // Add the request record to thread request records vector with proper
// locking
std::lock_guard<std::mutex> lock(thread_stat_->mu_);
auto total = end_time_sync - start_time_sync;
-      thread_stat_->request_timestamps_.emplace_back(std::make_tuple(
+      thread_stat_->request_records_.emplace_back(RequestRecord(
start_time_sync, std::move(end_time_syncs),
-          infer_data_.options_->sequence_end_, delayed));
+          infer_data_.options_->sequence_end_, delayed, sequence_id));
thread_stat_->status_ =
infer_backend_->ClientInferStat(&(thread_stat_->contexts_stat_[id_]));
if (!thread_stat_->status_.IsOk()) {
@@ -238,7 +242,7 @@ InferContext::AsyncCallbackFuncImpl(cb::InferResult* result)
std::shared_ptr<cb::InferResult> result_ptr(result);
bool is_final_response{true};
if (thread_stat_->cb_status_.IsOk()) {
-    // Add the request timestamp to thread Timestamp vector with
+    // Add the request record to thread request records vector with
// proper locking
std::lock_guard<std::mutex> lock(thread_stat_->mu_);
thread_stat_->cb_status_ = result_ptr->RequestStatus();
@@ -254,17 +258,19 @@
return;
}
if (is_null_response == false) {
-      it->second.end_times.push_back(std::chrono::system_clock::now());
+      it->second.response_times_.push_back(
+          std::chrono::system_clock::now());
}
thread_stat_->cb_status_ =
result_ptr->IsFinalResponse(&is_final_response);
if (thread_stat_->cb_status_.IsOk() == false) {
return;
}
if (is_final_response) {
-      thread_stat_->request_timestamps_.emplace_back(std::make_tuple(
-          it->second.start_time_, it->second.end_times,
-          it->second.sequence_end_, it->second.delayed_));
+      thread_stat_->request_records_.emplace_back(
+          it->second.start_time_, it->second.response_times_,
+          it->second.sequence_end_, it->second.delayed_,
+          it->second.sequence_id_);
infer_backend_->ClientInferStat(&(thread_stat_->contexts_stat_[id_]));
thread_stat_->cb_status_ = ValidateOutputs(result);
async_req_map_.erase(request_id);
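
The diff above amounts to a small per-request state machine: `SendRequest` parks a `RequestRecord` in `async_req_map_` keyed by request ID, each non-null response appends one timestamp to `response_times_`, and the final response moves the completed record into the thread's `request_records_`. A stripped-down sketch of that flow, with the perf_analyzer types replaced by simple stand-ins:

```cpp
#include <chrono>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

using time_point = std::chrono::time_point<std::chrono::system_clock>;

// Simplified stand-in for RequestRecord.
struct Record {
  time_point start_time_;
  std::vector<time_point> response_times_;
  uint64_t sequence_id_{0};
};

std::map<std::string, Record> async_req_map;  // in-flight requests
std::vector<Record> request_records;          // completed requests

void OnSend(const std::string& request_id, uint64_t sequence_id)
{
  // Park a record for the in-flight request, keyed by its ID.
  auto it = async_req_map.emplace(request_id, Record()).first;
  it->second.start_time_ = std::chrono::system_clock::now();
  it->second.sequence_id_ = sequence_id;
}

void OnResponse(const std::string& request_id, bool is_final)
{
  auto it = async_req_map.find(request_id);
  if (it == async_req_map.end()) {
    return;  // unknown request ID
  }
  // Each response contributes one timestamp; a decoupled model may
  // send several responses before the final one.
  it->second.response_times_.push_back(std::chrono::system_clock::now());
  if (is_final) {
    request_records.push_back(it->second);  // record is now complete
    async_req_map.erase(it);
  }
}
```

In the real code both containers are guarded by `thread_stat_->mu_`, since the send path and the callback run on different threads.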
36 changes: 13 additions & 23 deletions src/c++/perf_analyzer/infer_context.h
@@ -36,6 +36,7 @@
#include "iinfer_data_manager.h"
#include "infer_data.h"
#include "perf_utils.h"
+#include "request_record.h"
#include "sequence_manager.h"

namespace triton { namespace perfanalyzer {
@@ -55,31 +56,16 @@ struct ThreadStat {
// Tracks the amount of time this thread spent sleeping or waiting
IdleTimer idle_timer;

-  // A vector of request timestamps <start_time, end_time>
-  // Request latency will be end_time - start_time
-  TimestampVector request_timestamps_;
+  // A vector of request records
+  std::vector<RequestRecord> request_records_;
// A lock to protect thread data
std::mutex mu_;
// The number of sent requests by this thread.
std::atomic<size_t> num_sent_requests_{0};
};

-/// The properties of an asynchronous request required in
-/// the callback to effectively interpret the response.
-struct AsyncRequestProperties {
-  AsyncRequestProperties() : sequence_end_(false), delayed_(true) {}
-  // The timestamp of when the request was started.
-  std::chrono::time_point<std::chrono::system_clock> start_time_;
-  // Whether or not the request is at the end of a sequence.
-  bool sequence_end_;
-  // Whether or not the request is delayed as per schedule.
-  bool delayed_;
-  // Collection of response times
-  std::vector<std::chrono::time_point<std::chrono::system_clock>> end_times;
-};

#ifndef DOCTEST_CONFIG_DISABLE
-class MockInferContext;
+class NaggyMockInferContext;
#endif

/// Sends inference requests to the server
@@ -146,7 +132,11 @@ class InferContext {
/// A helper function to issue inference request to the server.
/// \param request_id The unique id to be associated with the request.
/// \param delayed Whether the request fell behind its scheduled time.
-  virtual void SendRequest(const uint64_t request_id, const bool delayed);
+  /// \param sequence_id Sequence ID of the request. Note that the default of
+  /// `0` means the request is not a sequence.
+  virtual void SendRequest(
+      const uint64_t request_id, const bool delayed,
+      const uint64_t sequence_id = 0);

/// Update inputs based on custom json data
void UpdateJsonData();
@@ -159,8 +149,8 @@
// Callback function for handling asynchronous requests
void AsyncCallbackFuncImpl(cb::InferResult* result);

-  const bool async_{false};
-  const bool streaming_{false};
+  bool async_{false};
+  bool streaming_{false};
const bool on_sequence_model_{false};
bool using_json_data_{false};
const int32_t batch_size_{0};
Expand All @@ -172,7 +162,7 @@ class InferContext {
std::shared_ptr<IInferDataManager> infer_data_manager_;

uint64_t request_id_ = 0;
-  std::map<std::string, AsyncRequestProperties> async_req_map_;
+  std::map<std::string, RequestRecord> async_req_map_;
std::atomic<uint> total_ongoing_requests_{0};
size_t data_step_id_;

@@ -203,7 +193,7 @@
std::shared_ptr<SequenceManager> sequence_manager_{nullptr};

#ifndef DOCTEST_CONFIG_DISABLE
-  friend MockInferContext;
+  friend NaggyMockInferContext;

public:
InferContext() = default;
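
The shift from the tuple to `RequestRecord` pays off at the use sites: opaque `std::get<N>` indexing becomes named field access, and the new `sequence_id_` slots in without renumbering anything. A side-by-side sketch (the tuple layout matches the old `TimestampVector` entries accessed via `std::get` in inference_profiler.cc below):

```cpp
#include <chrono>
#include <cstdint>
#include <tuple>
#include <vector>

using time_point = std::chrono::time_point<std::chrono::system_clock>;

// Old layout: <start time, response times, sequence_end, delayed>.
using Timestamp =
    std::tuple<time_point, std::vector<time_point>, bool, bool>;

// Old-style access: the reader must remember what each index means,
// and adding a field shifts everyone's indices.
bool WasDelayedOld(const Timestamp& ts) { return std::get<3>(ts); }

// New-style access through a named struct (sketched earlier) is
// self-documenting and extensible.
struct RequestRecord {
  time_point start_time_;
  std::vector<time_point> response_times_;
  bool sequence_end_{false};
  bool delayed_{true};
  uint64_t sequence_id_{0};  // the field this commit adds
};

bool WasDelayedNew(const RequestRecord& r) { return r.delayed_; }
```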
40 changes: 20 additions & 20 deletions src/c++/perf_analyzer/inference_profiler.cc
@@ -681,13 +681,13 @@ InferenceProfiler::ProfileHelper(
size_t completed_trials = 0;
std::queue<cb::Error> error;
std::deque<PerfStatus> measurement_perf_statuses;
-  all_timestamps_.clear();
+  all_request_records_.clear();
previous_window_end_ns_ = 0;

-  // Start with a fresh empty timestamp vector in the manager
+  // Start with a fresh empty request records vector in the manager
//
-  TimestampVector empty_timestamps;
-  RETURN_IF_ERROR(manager_->SwapTimestamps(empty_timestamps));
+  std::vector<RequestRecord> empty_request_records;
+  RETURN_IF_ERROR(manager_->SwapRequestRecords(empty_request_records));

do {
PerfStatus measurement_perf_status;
@@ -1193,11 +1193,11 @@ InferenceProfiler::Measure(
RETURN_IF_ERROR(manager_->GetAccumulatedClientStat(&end_stat));
prev_client_side_stats_ = end_stat;

-  TimestampVector current_timestamps;
-  RETURN_IF_ERROR(manager_->SwapTimestamps(current_timestamps));
-  all_timestamps_.insert(
-      all_timestamps_.end(), current_timestamps.begin(),
-      current_timestamps.end());
+  std::vector<RequestRecord> current_request_records;
+  RETURN_IF_ERROR(manager_->SwapRequestRecords(current_request_records));
+  all_request_records_.insert(
+      all_request_records_.end(), current_request_records.begin(),
+      current_request_records.end());

RETURN_IF_ERROR(Summarize(
start_status, end_status, start_stat, end_stat, perf_status,
@@ -1257,34 +1257,34 @@ InferenceProfiler::ValidLatencyMeasurement(
valid_sequence_count = 0;
response_count = 0;
std::vector<size_t> erase_indices{};
-  for (size_t i = 0; i < all_timestamps_.size(); i++) {
-    const auto& timestamp = all_timestamps_[i];
-    uint64_t request_start_ns = CHRONO_TO_NANOS(std::get<0>(timestamp));
-    uint64_t request_end_ns = CHRONO_TO_NANOS(std::get<1>(timestamp).back());
+  for (size_t i = 0; i < all_request_records_.size(); i++) {
+    const auto& request_record = all_request_records_[i];
+    uint64_t request_start_ns = CHRONO_TO_NANOS(request_record.start_time_);
+    uint64_t request_end_ns =
+        CHRONO_TO_NANOS(request_record.response_times_.back());

if (request_start_ns <= request_end_ns) {
// Only counting requests that end within the time interval
if ((request_end_ns >= valid_range.first) &&
(request_end_ns <= valid_range.second)) {
valid_latencies->push_back(request_end_ns - request_start_ns);
-        response_count += std::get<1>(timestamp).size();
+        response_count += request_record.response_times_.size();
erase_indices.push_back(i);
// Just add the sequence_end flag here.
-        if (std::get<2>(timestamp)) {
+        if (request_record.sequence_end_) {
valid_sequence_count++;
}
-        if (std::get<3>(timestamp)) {
+        if (request_record.delayed_) {
delayed_request_count++;
}
}
}
}

// Iterate through erase indices backwards so that erases from
-  // `all_timestamps_` happen from the back to the front to avoid using wrong
-  // indices after subsequent erases
+  // `all_request_records_` happen from the back to the front to avoid using
+  // wrong indices after subsequent erases
std::for_each(erase_indices.rbegin(), erase_indices.rend(), [this](size_t i) {
-    this->all_timestamps_.erase(this->all_timestamps_.begin() + i);
+    this->all_request_records_.erase(this->all_request_records_.begin() + i);
});

// Always sort measured latencies as percentile will be reported as default
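
The erase loop at the end of `ValidLatencyMeasurement` deserves a note: erasing from a vector shifts every later element left, so the ascending indices collected during the scan would point at the wrong elements after the first erase. Walking `erase_indices` in reverse keeps every not-yet-erased index valid. A minimal demonstration of the idiom:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

int main()
{
  std::vector<int> v{10, 11, 12, 13, 14};
  std::vector<size_t> erase_indices{1, 3};  // collected in ascending order

  // Erase back-to-front so earlier indices remain valid.
  std::for_each(erase_indices.rbegin(), erase_indices.rend(), [&v](size_t i) {
    v.erase(v.begin() + i);
  });

  // Elements at the original indices 1 and 3 are gone.
  assert((v == std::vector<int>{10, 12, 14}));
  return 0;
}
```

Erasing front-to-back with the same indices would wrongly drop 14 instead of 13, yielding {10, 12, 13}.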
10 changes: 5 additions & 5 deletions src/c++/perf_analyzer/inference_profiler.h
@@ -183,9 +183,9 @@ cb::Error ReportPrometheusMetrics(const Metrics& metrics);
/// time.
/// 2. After given time interval, the profiler gets end status from the server
/// and records the end time.
-/// 3. The profiler obtains the timestamps recorded by concurrency manager,
-/// and uses the timestamps that are recorded between start time and end time
-/// to measure client side status and update status_summary.
+/// 3. The profiler obtains the request records recorded by concurrency manager,
+/// and uses the request records that are recorded between start time and end
+/// time to measure client side status and update status_summary.
///
class InferenceProfiler {
public:
@@ -678,8 +678,8 @@ class InferenceProfiler {
bool include_server_stats_;
std::shared_ptr<MPIDriver> mpi_driver_;

-  /// The timestamps of the requests completed during all measurements
-  TimestampVector all_timestamps_;
+  /// The request records of the requests completed during all measurements
+  std::vector<RequestRecord> all_request_records_;

/// The end time of the previous measurement window
uint64_t previous_window_end_ns_;
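
Tying the header comment to the implementation: per step 3 above, only requests whose final response lands inside the measurement window count toward latency. A small sketch of that filter over `RequestRecord`s, mirroring the nanosecond comparison in `ValidLatencyMeasurement` (the `ToNanos` helper stands in for the `CHRONO_TO_NANOS` macro):

```cpp
#include <chrono>
#include <cstdint>
#include <utility>
#include <vector>

using time_point = std::chrono::time_point<std::chrono::system_clock>;

struct RequestRecord {
  time_point start_time_;
  std::vector<time_point> response_times_;
};

// Stand-in for the CHRONO_TO_NANOS macro used in inference_profiler.cc.
uint64_t ToNanos(const time_point& tp)
{
  return std::chrono::duration_cast<std::chrono::nanoseconds>(
             tp.time_since_epoch())
      .count();
}

// Keep latencies only for requests whose last response falls inside the
// window; latency is measured from send to final response.
std::vector<uint64_t> WindowLatencies(
    const std::vector<RequestRecord>& records,
    const std::pair<uint64_t, uint64_t>& valid_range)
{
  std::vector<uint64_t> latencies;
  for (const auto& r : records) {
    if (r.response_times_.empty()) {
      continue;  // no responses recorded for this request
    }
    const uint64_t start_ns = ToNanos(r.start_time_);
    const uint64_t end_ns = ToNanos(r.response_times_.back());
    if (start_ns <= end_ns && end_ns >= valid_range.first &&
        end_ns <= valid_range.second) {
      latencies.push_back(end_ns - start_ns);
    }
  }
  return latencies;
}
```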
[Diffs for the remaining 12 changed files are not expanded in this view.]
